"""
=================
Solvers Interface
=================
"""
from typing import Optional, Dict, Any, List, Union
import numpy as np
from anml.data.data import Data
from anml.models.interface import Model
from anml.exceptions import ANMLError
class ModelNotDefinedError(ANMLError):
    """Raised by ``Solver.assert_model_defined`` when no model is attached to the solver."""
class SolverNotDefinedError(ANMLError):
    """Raised by ``CompositeSolver.assert_solvers_defined`` when the composite holds no sub-solvers."""
class Solver:
    """Base class for solvers that operate on a single model.

    ``fit`` is abstract here and must be provided by subclasses. ``predict``
    evaluates the model at ``x_opt``, so a subclass's ``fit`` is expected to
    store its solution there.

    Parameters
    ----------
    model_instance : Model, optional
        The model this solver works on. It may also be attached later through
        the ``model`` property.
    """

    def __init__(self, model_instance: Optional[Model] = None):
        self._model = model_instance
        # Solution slots; they stay None until a subclass's ``fit`` fills them.
        # NOTE(review): fun_val_opt is presumably the objective value at
        # x_opt -- confirm against the concrete solvers.
        self.x_opt = None
        self.fun_val_opt = None

    @property
    def model(self):
        """The model attached to this solver, or ``None`` if none is set."""
        return self._model

    @model.setter
    def model(self, model_instance: Model):
        self._model = model_instance

    def assert_model_defined(self):
        """Ensure a model is attached.

        Raises
        ------
        ModelNotDefinedError
            If ``model`` is ``None``.
        """
        if self._model is None:
            raise ModelNotDefinedError()

    def fit(self, x_init: Optional[np.ndarray] = None, data: Optional[Data] = None, options: Optional[Dict[str, Any]] = None, **kwargs):
        """Fit the model; must be overridden by subclasses.

        The meaning of ``x_init``, ``data``, ``options`` and any extra keyword
        arguments is defined by each subclass.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()

    def predict(self, **kwargs):
        """Evaluate the model at the fitted variables ``x_opt``.

        Keyword arguments are forwarded unchanged to ``model.forward``. If
        ``fit`` has not populated ``x_opt``, ``None`` is passed through as-is.

        Raises
        ------
        ModelNotDefinedError
            If no model is attached.
        """
        # Fail with a descriptive error instead of an AttributeError on None.
        self.assert_model_defined()
        return self.model.forward(self.x_opt, **kwargs)
class CompositeSolver(Solver):
    """Solver that manages a collection of sub-solvers.

    The ``model`` property fans out to the sub-solvers. Reading it returns the
    list of each sub-solver's model. Assigning it sets either one shared model
    on every sub-solver, or one model per sub-solver when given a sequence.

    Parameters
    ----------
    solvers_list : list of Solver, optional
        Initial sub-solvers. The sequence is copied, so ``add_solver`` never
        mutates a list owned by the caller.
    """

    def __init__(self, solvers_list: Optional[List[Solver]] = None):
        super().__init__(model_instance=None)
        # Copy: storing the caller's list directly would let add_solver()
        # append into it, and a tuple argument would make add_solver() fail.
        self._solvers = [] if solvers_list is None else list(solvers_list)

    @property
    def solvers(self):
        """The internal list of sub-solvers."""
        return self._solvers

    @solvers.setter
    def solvers(self, solvers_list: List[Solver]):
        # Copy for the same aliasing reason as in __init__.
        self._solvers = list(solvers_list)

    def add_solver(self, solver: Solver):
        """Append ``solver`` to the sub-solvers."""
        self._solvers.append(solver)

    @property
    def model(self):
        """List of the sub-solvers' models, in sub-solver order.

        Raises
        ------
        SolverNotDefinedError
            If there are no sub-solvers.
        """
        self.assert_solvers_defined()
        return [solver.model for solver in self._solvers]

    @model.setter
    def model(self, model_instances: Union[Model, List[Model]]):
        # A list/tuple assigns models pairwise to the sub-solvers and must
        # match their count; any other value is shared by every sub-solver.
        # Raises SolverNotDefinedError when there are no sub-solvers, and
        # ValueError on a length mismatch.
        self.assert_solvers_defined()
        if isinstance(model_instances, (list, tuple)):
            if len(model_instances) != len(self._solvers):
                raise ValueError(
                    'When passing in multiple models its length should equal to the number of solvers passed in.'
                )
            for model, solver in zip(model_instances, self._solvers):
                solver.model = model
        else:
            for solver in self._solvers:
                solver.model = model_instances

    def assert_model_defined(self):
        """Ensure there is at least one sub-solver and each has a model.

        Raises
        ------
        SolverNotDefinedError
            If there are no sub-solvers.
        ModelNotDefinedError
            If any sub-solver has no model attached.
        """
        self.assert_solvers_defined()
        for solver in self._solvers:
            solver.assert_model_defined()

    def assert_solvers_defined(self):
        """Ensure at least one sub-solver is present.

        Raises
        ------
        SolverNotDefinedError
            If the sub-solver list is empty.
        """
        if not self._solvers:
            raise SolverNotDefinedError()