Source code for anml.data.data

"""
===============
Data Management
===============

Data is managed and processed using :class:`~anml.data.data.Data`
with specifications provided through one or more
instances of :class:`~anml.data.data_specs.DataSpecs`.
"""

from collections import defaultdict
from typing import Union, List, Optional, Dict, Any

import numpy as np
import pandas as pd

from anml.data.data_specs import DataSpecs, _check_compatible_specs
from anml.exceptions import ANMLError
from anml.parameter.parameter import ParameterSet


class DataError(ANMLError):
    """Root of the exception hierarchy used by the data module."""
class DataTypeError(DataError):
    """Raised when the type of the supplied data is not understood."""
class EmptySpecsError(DataError):
    """Raised when an operation cannot be performed because no
    specifications are associated with the Data instance.
    """
class Data:
    """A data manager that takes data as inputs along with data specs and
    transforms into primitive types for use in the optimization.

    Parameters
    ----------
    data_specs
        A data specification object, or list of data specification objects
        that indicate what the columns of a data frame represent.
    param_set
        A parameter set that has covariate specifications, or list of these
        sets.

    Attributes
    ----------
    data
        A dictionary of numpy ndarrays keyed by the column attribute in
        _data_specs (with the leading ``col_`` marker removed), extracted
        from the data frame after doing self.process_data(). If _data_specs
        has multiple elements, then the values will be a list of numpy
        ndarrays, in the order of _data_specs.
    covariates
        A list of dictionaries; created empty by ``__init__``.
    groups_info
        A ``defaultdict(dict)``; created empty by ``__init__``.
    """

    def __init__(self,
                 data_specs: Optional[Union[DataSpecs, List[DataSpecs]]] = None,
                 param_set: Optional[Union[ParameterSet, List[ParameterSet]]] = None):
        self._data_specs = []
        self._param_set = []
        self._df = None

        if data_specs is not None:
            self.set_data_specs(data_specs)
        if param_set is not None:
            self.set_param_set(param_set)

        self.data: Dict[str, Union[np.ndarray, List[np.ndarray]]] = dict()
        self.covariates: List[Dict[str, Any]] = list()
        self.groups_info = defaultdict(dict)

    @property
    def data_spec_col_attributes(self):
        """The ``_col_attributes`` of the first attached data spec.

        Raises ``IndexError`` when no data specs are attached.
        """
        return self._data_specs[0]._col_attributes

    @property
    def _unique_covariates(self):
        """Set of every entry found in ``_flat_covariates`` across all
        attached parameter sets.
        """
        return {
            covariate
            for p_set in self._param_set
            for covariate in p_set._flat_covariates
        }

    @property
    def multi_spec(self):
        """True when more than one data spec is attached."""
        return len(self._data_specs) > 1

    @property
    def multi_param_set(self):
        """True when more than one parameter set is attached."""
        return len(self._param_set) > 1

    @staticmethod
    def _col_to_attribute(x: str) -> str:
        """Return the part of ``x`` that follows its first ``col_`` marker.

        Only the first marker is removed, so a name that happens to contain
        ``col_`` again further along (``'col_protocol_id'``) keeps the rest
        intact (``'protocol_id'``). A name without any ``col_`` marker maps
        to the empty string.
        """
        # partition splits on the first occurrence only; splitting on every
        # occurrence and re-joining would also delete interior 'col_' runs.
        _, _, attribute = x.partition('col_')
        return attribute

    def set_data_specs(self, data_specs: Union[DataSpecs, List[DataSpecs]]):
        """Updates the data specifications, or sets them if they are empty.

        Parameters
        ----------
        data_specs
            A data specification object, or list of data specification
            objects that indicate what the columns of a data frame represent.
            A list is passed through ``_check_compatible_specs`` before it is
            stored; a single object is wrapped in a one-element list.
        """
        if isinstance(data_specs, list):
            _check_compatible_specs(data_specs)
            self._data_specs = data_specs
        else:
            self._data_specs = [data_specs]

    def set_param_set(self, param_set: Union[ParameterSet, List[ParameterSet]]):
        """Updates the parameter sets, or sets them if they are empty.

        Parameters
        ----------
        param_set
            A parameter set, or list of parameter sets. A single object is
            wrapped in a one-element list.
        """
        if isinstance(param_set, list):
            self._param_set = param_set
        else:
            self._param_set = [param_set]

    def detach_data_specs(self):
        """Remove existing data specs."""
        self._data_specs = list()

    def detach_param_set(self):
        """Remove existing parameter set."""
        self._param_set = list()

    def process_data(self, df: pd.DataFrame):
        """Process a data frame and attach to this instance with existing
        data specs.

        Parameters
        ----------
        df
            A pandas.DataFrame with all of the information that the existing
            data specifications needs.

        Raises
        ------
        DataTypeError
            If ``df`` is not a pandas.DataFrame.
        EmptySpecsError
            If no data specs are attached to this instance.
        """
        if not isinstance(df, pd.DataFrame):
            raise DataTypeError("Data to attach must be in the form of a pandas.DataFrame.")
        # Check for specs before copying so that a failed call neither pays
        # for the copy nor leaves a frame attached to the instance.
        if not self._data_specs:
            raise EmptySpecsError("Need to attach data specs before processing data.")

        # Work on a private copy of the caller's frame.
        self._df = df.copy()
        for spec in self._data_specs:
            spec._validate_df(df=self._df)

        for attribute in self.data_spec_col_attributes:
            name = self._col_to_attribute(attribute)
            # One array per spec, in the order of _data_specs.
            arrays = [
                self._df[getattr(spec, attribute)].to_numpy()
                for spec in self._data_specs
            ]
            # With a single spec the array is stored directly, not in a list.
            self.data[name] = arrays if self.multi_spec else arrays[0]