"""
===============
Spline Variable
===============
A subclass of :class:`anml.parameter.variables.Variable` that handles spline-related computations.
"""
from dataclasses import dataclass, field
from typing import List, Optional, Union
import pandas as pd
import numpy as np
from xspline import XSpline
from anml.parameter.variables import Variable, VariableError
from anml.parameter.prior import Prior
@dataclass
class SplineLinearConstr:
    """Constraints on spline derivatives.

    The general form is

        lb <= Ax <= ub

    where x is in some interval domain `x_domain`, and A can be the 0th, 1st or 2nd order
    derivative matrix of the splines evaluated at some discretization points.
    A is not known at the initialization of this object, but will have dimension
    `grid_size` by `number of spline basis`.
    `lb` and `ub` are vectors of multiples of ones.

    This type of constraint can be used to impose monotonicity and convexity constraints.
    For instance, for splines defined on `[0, 5]`, one can specify monotonically decreasing on
    `[0, 1]` with
    `constr = SplineLinearConstr(x_domain=[0, 1], y_bounds=[-np.inf, 0.0], order=1)`,
    monotonically increasing on `[4, 5]` with
    `constr = SplineLinearConstr(x_domain=[4, 5], y_bounds=[0.0, np.inf], order=1)`,
    and overall convexity with
    `constr = SplineLinearConstr(x_domain=[0, 5], y_bounds=[0.0, np.inf], order=2)`.

    Parameters
    ----------
    order: int
        order of the derivative
    y_bounds: List[float]
        two-element list ``[lower, upper]`` of bounds for y = Ax
    x_domain: List[float], optional
        two-element list ``[start, end]`` giving the domain for x, defaults to ``[-inf, inf]``
    grid_size: Optional[int], optional
        number of discretization points in the domain; if None, a global grid size has to be
        supplied by the spline variable that uses this constraint

    Raises
    ------
    ValueError
        `x_domain` or `y_bounds` does not have exactly two entries.
    ValueError
        domain for x is not valid.
    ValueError
        bounds for y = Ax is not valid.
    ValueError
        invalid derivative order
    ValueError
        invalid grid size
    """

    order: int
    y_bounds: List[float]
    x_domain: List[float] = field(default_factory=lambda: [-np.inf, np.inf])
    grid_size: Optional[int] = None

    def __post_init__(self):
        # Both intervals are indexed as [start, end] below and by consumers of this object;
        # reject anything else up front instead of failing with an IndexError or silently
        # ignoring extra entries.
        if len(self.x_domain) != 2:
            raise ValueError('Domain must be specified as [start, end].')
        if len(self.y_bounds) != 2:
            raise ValueError('Bounds must be specified as [lower, upper].')
        if self.x_domain[0] >= self.x_domain[1]:
            raise ValueError('Domain must have positive length.')
        if self.y_bounds[0] > self.y_bounds[1]:
            raise ValueError('Lower bound cannot be greater than upper bound.')
        if self.order < 0:
            raise ValueError('Order of derivative must be nonnegative.')
        if self.grid_size is not None and self.grid_size < 1:
            raise ValueError('Grid size must be at least 1.')
@dataclass
class Spline(Variable):
    """Spline variable.

    Parameters
    ----------
    knots_type : str
        type of knots; can only be 'frequency' (knots at evenly spaced quantiles of the
        covariate) or 'domain' (knots evenly spaced between the covariate's min and max)
    knots_num: int
        number of knots
    degree: int
        degree of splines
    l_linear: bool
        whether left tail is linear
    r_linear: bool
        whether right tail is linear
    include_intercept: bool
        whether to include intercept in design matrix; when False the first spline basis
        column is dropped
    derivative_constr: List[`~anml.parameter.spline_variable.SplineLinearConstr`]
        constraints on derivatives
    constr_grid_size_global: Optional[int], optional
        number of points to use when building constraint matrix. Used only when `grid_size` for
        an individual `~anml.parameter.spline_variable.SplineLinearConstr` is not available.

    Raises
    ------
    VariableError
        unknown knot type
    VariableError
        no covariate has been set
    VariableError
        spline used before `create_spline` has been called
    """

    fe_prior: Optional[Prior] = field(init=False)
    add_re: bool = field(init=False)
    knots_type: str = 'frequency'
    knots_num: int = 3
    degree: int = 3
    l_linear: bool = False
    r_linear: bool = False
    include_intercept: bool = False
    derivative_constr: List[SplineLinearConstr] = field(default_factory=list)
    constr_grid_size_global: Optional[int] = None

    # Populated later by `create_spline` (spline, x) and `build_constraint_matrix_fe`
    # (constr_matrix_fe, constr_lb_fe, constr_ub_fe).
    spline: Optional[XSpline] = field(init=False)
    x: Optional[np.ndarray] = field(init=False)
    constr_matrix_fe: Optional[np.ndarray] = field(init=False)
    constr_lb_fe: Optional[Union[List[float], np.ndarray]] = field(init=False)
    constr_ub_fe: Optional[Union[List[float], np.ndarray]] = field(init=False)

    def __post_init__(self):
        if self.knots_type not in ['frequency', 'domain']:
            raise VariableError(f"Unknown knots_type for Spline {self.knots_type}.")
        self.spline = None
        self.add_re = False
        self.fe_prior = None
        Variable.__post_init__(self)
        if self.fe_prior is None:
            # Fall back to a prior with infinite bounds, one entry per fixed effect.
            num_fe = self._count_num_fe()
            self.set_fe_prior(
                Prior(lower_bound=[-np.inf] * num_fe,
                      upper_bound=[np.inf] * num_fe)
            )

    def _count_num_fe(self):
        """Return the number of fixed-effect coefficients.

        This is the spline basis count ``knots_num - l_linear - r_linear + degree - 1``
        (NOTE(review): presumably matches XSpline's basis count — confirm), minus one when
        the intercept column is dropped.
        """
        return self.knots_num - self.l_linear - self.r_linear + self.degree - 1 - int(not self.include_intercept)

    def set_fe_prior(self, prior: Prior):
        """Set the fixed-effect prior after checking its dimension against `self.num_fe`.

        Raises
        ------
        ValueError
            `prior.x_dim` differs from `self.num_fe`.
        """
        if prior.x_dim != self.num_fe:
            raise ValueError(f'Dimension of fe_prior = {prior.x_dim} should match num_fe = {self.num_fe}.')
        self.fe_prior = prior

    def create_spline(self, df: pd.DataFrame):
        """Build `self.spline` from the covariate column of `df` and store its values in `self.x`.

        Raises
        ------
        VariableError
            no covariate has been set, or `knots_type` is unknown.
        """
        if self.covariate is None:
            raise VariableError("No covariate has been set.")
        self.x = df[self.covariate].values
        spline_knots = np.linspace(0, 1, self.knots_num)
        if self.knots_type == 'frequency':
            knots = np.quantile(self.x, spline_knots)
        elif self.knots_type == 'domain':
            knots = np.min(self.x) + spline_knots * (np.max(self.x) - np.min(self.x))
        else:
            # Validated in __post_init__, but the attribute can be reassigned afterwards.
            raise VariableError(f"Unknown knots_type for Spline {self.knots_type}.")
        self.spline = XSpline(
            knots=knots,
            degree=self.degree,
            l_linear=self.l_linear,
            r_linear=self.r_linear
        )

    def _design_matrix(self, df: pd.DataFrame, create_spline: bool = True) -> np.ndarray:
        """Evaluate the spline basis at the covariate values of `df`.

        Parameters
        ----------
        df: pd.DataFrame
            data containing the covariate column.
        create_spline: bool
            if True, (re)build the spline from `df` first; otherwise reuse the existing spline.

        Returns
        -------
        np.ndarray
            spline design matrix, without its first column unless `include_intercept` is True.
        """
        if create_spline:
            self.create_spline(df)
        if self.spline is None:
            raise VariableError("Spline has not been created; call create_spline first.")
        if self.covariate is None:
            raise VariableError("No covariate has been set.")
        # Evaluate on `df` itself: with create_spline=False the spline (and self.x) may come
        # from different data, so self.x is not necessarily the covariate of `df`.
        design_mat = self.spline.design_mat(df[self.covariate].values)
        return design_mat if self.include_intercept else design_mat[:, 1:]

    def build_constraint_matrix_fe(self):
        """Build the constraint matrix and bounds for
        `constr_lb_fe` <= `constr_matrix_fe` @ beta <= `constr_ub_fe`.

        The results are stored in `self.constr_matrix_fe`, `self.constr_lb_fe` and
        `self.constr_ub_fe`; nothing is returned. Without derivative constraints a single
        all-zero row with bounds ``[0.0]`` is stored.

        Raises
        ------
        VariableError
            the spline has not been created yet.
        ValueError
            a constraint domain does not overlap the covariate range, or neither an individual
            nor a global grid size is available.
        """
        if self.spline is None:
            raise VariableError("Spline has not been created; call create_spline first.")
        lb, ub = np.min(self.x), np.max(self.x)
        constr_matrices = []
        constr_lbs = []
        constr_ubs = []
        for constr in self.derivative_constr:
            if constr.x_domain[0] >= ub or constr.x_domain[1] <= lb:
                raise ValueError(
                    f'Domain of constraint = {constr.x_domain} does not'
                    f' overlap with domain of spline. lb = {lb}, ub = {ub}.'
                )
            if constr.grid_size is not None:
                points = np.linspace(max(lb, constr.x_domain[0]), min(ub, constr.x_domain[1]), constr.grid_size)
            elif self.constr_grid_size_global is not None:
                points_all = np.linspace(lb, ub, self.constr_grid_size_global)
                # Element-wise mask: a chained comparison `a <= arr <= b` calls bool() on an
                # array and raises for more than one element.
                is_in_domain = (points_all >= constr.x_domain[0]) & (points_all <= constr.x_domain[1])
                points = points_all[is_in_domain]
            else:
                raise ValueError('Either global or individual constraint grid size needs to be specified.')
            n_points = len(points)
            dmat = self.spline.design_dmat(points, constr.order)
            constr_matrices.append(dmat if self.include_intercept else dmat[:, 1:])
            constr_lbs.append([constr.y_bounds[0]] * n_points)
            constr_ubs.append([constr.y_bounds[1]] * n_points)
        if len(constr_matrices) > 0:
            self.constr_matrix_fe = np.vstack(constr_matrices)
            self.constr_lb_fe = np.hstack(constr_lbs)
            self.constr_ub_fe = np.hstack(constr_ubs)
        else:
            self.constr_matrix_fe = np.zeros((1, self.num_fe))
            self.constr_lb_fe = [0.0]
            self.constr_ub_fe = [0.0]