Source code for anml.models.interface

"""
================
Model Interface 
================

An interface for models.
"""
from abc import ABC
from typing import Optional

import numpy as np

from anml.data.data import Data


[docs]class Model: """Interface for models. """ def __init__(self): pass
[docs] def objective(self, x: np.ndarray, data: Data): """Objective function for a model. Parameters ---------- x : np.ndarray input vector data : Data a :class`~anml.data.data.Data` object Raises ------ NotImplementedError not implemented in this interface. """ raise NotImplementedError()
[docs] def gradient(self, x: np.ndarray, data: Data): """Gradient of objective function computed using complex step method. Can be overwritten in inherited classes. Parameters ---------- x : np.ndarray inpute vector data : Data a :class`~anml.data.data.Data` object. Returns ------- np.ndarray gradient vector """ step = 1e-16 x_c = x + 0j grad = np.zeros(x.size) for i in range(x.size): x_c[i] += step*1j grad[i] = self.objective(x_c, data).imag/step x_c[i] -= step*1j return grad
[docs] def forward(self, x: np.ndarray, *args, **kwargs): """Compute an output based on the generating mechanism defined by the model. Parameters ---------- x : np.ndarray input vector Raises ------ NotImplementedError not implemented in interface """ # different from predict() in solver in the sense that both variable and data value can vary. # in predict() variable value is at taken to be the optimal. raise NotImplementedError()
[docs]class TrimmingCompatibleModel(Model, ABC): def __init__(self): super().__init__() def _gradient(self, x: np.ndarray, data: Data) -> np.ndarray: """ This returns the gradient function by data point. So it's an array :param x: :param data: :return: """
[docs] def gradient(self, x: np.ndarray, data: Data) -> np.ndarray: """Gradient of objective function computed using complex step method. Can be overwritten in inherited classes. Parameters ---------- x : np.ndarray inpute vector data : Data a :class`~anml.data.data.Data` object. Returns ------- np.ndarray gradient vector """
def _objective(self, x: np.ndarray, data: Data) -> np.ndarray: """ This returns the objective function by data point. :param x: :param data: :return: """
[docs] def objective(self, x: np.ndarray, data: Data, w: Optional[np.ndarray] = None) -> float: """Objective function for a model. This objective function Parameters ---------- x : np.ndarray input vector data : Data a :class`~anml.data.data.Data` object w : An optional weights vector Raises ------ NotImplementedError not implemented in this interface. """ raise NotImplementedError()