# Source code for dival.reconstructors.regression_reconstructors

# -*- coding: utf-8 -*-
"""Provides reconstructors performing regression."""
import os
import json
import numpy as np
from sklearn.linear_model import Ridge
from dival.util.odl_utility import uniform_discr_element
from dival import LearnedReconstructor


class LinRegReconstructor(LearnedReconstructor):
    """Reconstructor learning and applying linear regression.

    Assumes the inverse operator is linear, i.e. ``x = A_inv * y``.
    Learns the entries of ``A_inv`` by l2-regularized linear regression
    (ridge regression without intercept)::

        A_inv = argmin_A  sum_i ||x_i - A * y_i||^2 + alpha * ||A||_F^2

    where (y_i, x_i) with i=1,...,N are pairs of observations and the
    corresponding ground truth, and ``alpha`` is the hyper parameter
    ``'l2_regularization'``. This is the objective minimized by
    :class:`sklearn.linear_model.Ridge`; note that it is not normalized by
    the number of samples ``N``.

    Observations and ground truth samples are flattened (C order) for the
    regression, and reconstructions are reshaped to ``reco_space.shape``.

    Attributes
    ----------
    weights : :class:`np.ndarray`
        The weight matrix of shape ``(reco_size, observation_size)``, where
        the sizes are the numbers of elements of a single (flattened)
        ground truth and observation sample, respectively.
    """

    HYPER_PARAMS = {
        'l2_regularization': {
            'default': 0.,
            'range': [0., np.inf],
            'retrain': True,
        },
    }

    def __init__(self, hyper_params=None, **kwargs):
        """
        Parameters
        ----------
        hyper_params : dict, optional
            A dict with no items or an item ``'l2_regularization': float``.
            Cf. :meth:`Reconstructor.init`.
        """
        super().__init__(hyper_params=hyper_params, **kwargs)
        self.weights = None

    def _reconstruct(self, observation):
        # The weights act on flattened observations; flatten the input in
        # the same (C) order as in `train`, then restore the space's shape.
        flat_observation = np.asarray(observation).ravel()
        reconstruction = np.dot(self.weights, flat_observation)
        return self.reco_space.element(
            reconstruction.reshape(self.reco_space.shape))

    def train(self, dataset):
        """Fit :attr:`weights` by ridge regression on the training data.

        Parameters
        ----------
        dataset : :class:`.Dataset`
            Dataset providing ``(observation, ground_truth)`` training pairs
            via ``get_train_generator()``. Its shapes (``get_shape()``) must
            match :attr:`observation_space` and :attr:`reco_space` if these
            are set.

        Raises
        ------
        ValueError
            If the dataset shapes do not match the spaces of the
            reconstructor, or if the training generator yields fewer
            samples than ``dataset.get_train_len()``.
        """
        observation_shape, reco_shape = dataset.get_shape()
        self._check_shapes(observation_shape, reco_shape)
        observations, ground_truths = self._collect_train_data(
            dataset, observation_shape, reco_shape)
        # ``fit_intercept=False``: the model is purely linear
        # (``x = A_inv * y``) and `_reconstruct` applies only the weights,
        # so a fitted intercept would be silently discarded.
        ridge = Ridge(alpha=self.hyper_params['l2_regularization'],
                      fit_intercept=False)
        ridge.fit(observations, ground_truths)
        self.weights = ridge.coef_

    def _check_shapes(self, observation_shape, reco_shape):
        """Raise ``ValueError`` if dataset shapes mismatch the set spaces."""
        if (self.observation_space is not None and
                self.observation_space.shape != observation_shape):
            raise ValueError('Observation shape of dataset not matching '
                             '`self.observation_space.shape`')
        if (self.reco_space is not None and
                self.reco_space.shape != reco_shape):
            raise ValueError('Reconstruction shape of dataset not matching '
                             '`self.reco_space.shape`')

    @staticmethod
    def _collect_train_data(dataset, observation_shape, reco_shape):
        """Stack the flattened training pairs of `dataset` into 2D arrays.

        Returns a tuple ``(observations, ground_truths)`` of arrays with
        shapes ``(train_len, observation_size)`` and
        ``(train_len, reco_size)``.
        """
        train_len = dataset.get_train_len()
        n_features = int(np.prod(observation_shape))
        n_targets = int(np.prod(reco_shape))
        observations = np.empty((train_len, n_features))
        ground_truths = np.empty((train_len, n_targets))
        n_samples = 0
        for i, (obs, gt) in enumerate(dataset.get_train_generator()):
            observations[i] = np.asarray(obs).ravel()
            ground_truths[i] = np.asarray(gt).ravel()
            n_samples = i + 1
        if n_samples != train_len:
            # Remaining rows of the `np.empty` buffers would contain
            # uninitialized memory and silently corrupt the fit.
            raise ValueError(
                'Training generator of dataset yielded {} samples, but '
                '`dataset.get_train_len()` is {}'.format(n_samples,
                                                         train_len))
        return observations, ground_truths

    def save_params(self, path):
        """
        Save :attr:`weights` and :attr:`hyper_params` to files.

        Parameters
        ----------
        path : str
            Folder. It is created (including missing parent folders) if it
            does not exist.
        """
        os.makedirs(path, exist_ok=True)
        np.save(os.path.join(path, 'weights.npy'), self.weights)
        with open(os.path.join(path, 'hyper_params.json'), 'w') as file:
            # indent=1 yields the same output as the former ``indent=True``
            json.dump(self.hyper_params, file, indent=1)

    def load_params(self, path):
        """
        Load :attr:`weights` and :attr:`hyper_params` from files.

        Parameters
        ----------
        path : str
            Folder containing ``weights.npy`` and ``hyper_params.json`` as
            written by :meth:`save_params`.
        """
        self.weights = np.load(os.path.join(path, 'weights.npy'))
        with open(os.path.join(path, 'hyper_params.json'), 'r') as file:
            # Loaded values override the current hyper parameters; keys not
            # present in the file are kept.
            self.hyper_params.update(json.load(file))