Source code for anml.solvers.composite

from typing import Callable, Optional, Dict, Any
"""
=================
Composite Solvers
=================

Composite solvers for optimization. Composition or decorator of solvers.
"""

import numpy as np

from anml.data.data import Data
from anml.solvers.interface import Solver, CompositeSolver


class MultipleInitializations(CompositeSolver):
    """Run one wrapped solver from several starting points and keep the best.

    Parameters
    ----------
    sample_fun
        Callable invoked as ``sample_fun(x_init)``. It must return an
        iterable of initial points; each one is passed to the wrapped solver.
    solver
        Optional solver to wrap. When given, it becomes the single entry of
        ``self.solvers``, which is what :meth:`fit` and :meth:`predict` use.
    """

    def __init__(self, sample_fun: Callable, solver: Optional[Solver] = None):
        super().__init__()
        self.sample_fun = sample_fun
        # fit() and predict() read ``self.solvers[0]``, so a solver supplied
        # here has to be registered or it would never be used.
        # NOTE(review): assumes the base class exposes ``solvers`` as a plain
        # assignable attribute -- confirm against CompositeSolver.
        if solver is not None:
            self.solvers = [solver]

    def fit(self, x_init: Optional[np.ndarray] = None,
            data: Optional[Data] = None,
            options: Optional[Dict[str, Any]] = None):
        """Fit the wrapped solver once per sampled initial point.

        The initial points come from ``self.sample_fun(x_init)``. After every
        run the wrapped solver's ``fun_val_opt`` and ``x_opt`` are recorded;
        the pair with the smallest objective value is stored on this object
        as ``self.fun_val_opt`` and ``self.x_opt``.

        Parameters
        ----------
        x_init
            Value forwarded to ``sample_fun`` to generate the initial points.
        data
            Forwarded unchanged to the wrapped solver's ``fit``.
        options
            Forwarded unchanged to the wrapped solver's ``fit``.

        Raises
        ------
        RuntimeError
            If more than one solver is registered.
        ValueError
            If ``sample_fun`` yields no initial points.
        """
        self.assert_solvers_defined()
        if len(self.solvers) > 1:
            raise RuntimeError('Only implemented for single solver.')

        solver = self.solvers[0]
        fun_vals = []
        xs_opt = []
        for x in self.sample_fun(x_init):
            solver.fit(data=data, x_init=x, options=options)
            fun_vals.append(solver.fun_val_opt)
            xs_opt.append(solver.x_opt)

        # Without any run there is no optimum to select; fail with a clear
        # message instead of numpy's "empty sequence" error from argmin.
        if not fun_vals:
            raise ValueError('sample_fun returned no initial points.')

        self.x_opt = xs_opt[np.argmin(fun_vals)]
        self.fun_val_opt = np.min(fun_vals)

    def predict(self, **kwargs):
        """Return the wrapped solver's prediction at the stored ``x_opt``.

        All keyword arguments are forwarded to the wrapped solver's
        ``predict``. ``self.x_opt`` is whatever the last :meth:`fit` stored.
        """
        return self.solvers[0].predict(self.x_opt, **kwargs)