# Source code for dival.reconstructors.reconstructor

# -*- coding: utf-8 -*-
"""Provides the abstract reconstructor base class."""
import os
from inspect import signature, Parameter
import json
from warnings import warn
from copy import deepcopy


class _ReconstructorMeta(type):
    def __init__(cls, name, bases, dct):
        def get_fget(k):
            def fget(self):
                return self.hyper_params[k]
            return fget

        def get_fset(k):
            def fset(self, v):
                self.hyper_params[k] = v
            return fset

        for k in cls.HYPER_PARAMS.keys():
            if k.isidentifier():
                fget = get_fget(k)
                fset = get_fset(k)
                setattr(cls, '_fget_{}'.format(k), fget)
                setattr(cls, '_fset_{}'.format(k), fset)
                setattr(cls, k, property(fget, fset))


class Reconstructor(metaclass=_ReconstructorMeta):
    """Abstract reconstructor base class.

    There are two ways of implementing a `Reconstructor` subclass:

        * Implement :meth:`reconstruct`. It has to support optional in-place
          and out-of-place evaluation.
        * Implement :meth:`_reconstruct`. It must have one of the following
          signatures:

            - ``_reconstruct(self, observation, out)`` (in-place)
            - ``_reconstruct(self, observation)`` (out-of-place)
            - ``_reconstruct(self, observation, out=None)`` (optional
              in-place)

    The class attribute :attr:`HYPER_PARAMS` defines the hyper parameters of
    the reconstructor class.
    The current values for a reconstructor instance are given by the
    attribute :attr:`hyper_params`.
    Properties wrapping :attr:`hyper_params` are automatically created by the
    metaclass (for hyper parameter names that are valid identifiers), such
    that the hyper parameters can be written and read like instance
    attributes.

    Attributes
    ----------
    reco_space : :class:`odl.discr.DiscretizedSpace`, optional
        Reconstruction space.
    observation_space : :class:`odl.discr.DiscretizedSpace`, optional
        Observation space.
    name : str
        Name of the reconstructor.
    hyper_params : dict
        Current hyper parameter values.
        Initialized automatically using the default values from
        :attr:`HYPER_PARAMS` (but may be overridden by `hyper_params` passed
        to :meth:`__init__`).
        It is expected to have the same keys as :attr:`HYPER_PARAMS`.
        The values for these keys in this dict are wrapped by properties with
        the key as identifier (if possible), so an assignment to the property
        changes the value in this dict and vice versa.
    """
    HYPER_PARAMS = {}
    """Specification of hyper parameters.

    This class attribute is a dict that lists the hyper parameters of the
    reconstructor.
    It should not be hidden by an instance attribute of the same name (i.e. by
    assigning a value to `self.HYPER_PARAMS` in an instance of a subtype).

    *Note:* in order to inherit :attr:`HYPER_PARAMS` from a super class, the
    subclass should create a deep copy of it, i.e. execute
    ``HYPER_PARAMS = copy.deepcopy(SuperReconstructorClass.HYPER_PARAMS)`` in
    the class body.

    The keys of this dict are the names of the hyper parameters, and each
    value is a dict with the following fields.

        Standard fields:

            ``'default'``
                Default value.
            ``'retrain'`` : bool, optional
                Whether training depends on the parameter. Default: ``False``.
                Any custom subclass of `LearnedReconstructor` must set this
                field to ``True`` if training depends on the parameter value.

        Hyper parameter search fields:

            ``'range'`` : (float, float), optional
                Interval of valid values. If this field is set, the parameter
                is taken to be real-valued.
                Either ``'range'`` or ``'choices'`` has to be set.
            ``'choices'`` : sequence, optional
                Sequence of valid values of any type.
                If this field is set, ``'range'`` is ignored. Can be used to
                perform manual grid search. Either ``'range'`` or
                ``'choices'`` has to be set.
            ``'method'`` : {'grid_search', 'hyperopt'}, optional
                Optimization method for the parameter.
                Default: ``'grid_search'``.
                Options are:

                    ``'grid_search'``
                        Grid search over a sequence of fixed values. Can be
                        configured by the dict ``'grid_search_options'``.
                    ``'hyperopt'``
                        Random search using the ``hyperopt`` package. Can be
                        configured by the dict ``'hyperopt_options'``.

            ``'grid_search_options'`` : dict
                Option dict for grid search.

                The following fields determine how ``'range'`` is sampled
                (in case it is specified and no ``'choices'`` are specified):

                    ``'num_samples'`` : int, optional
                        Number of values. Default: ``10``.
                    ``'type'`` : {'linear', 'logarithmic'}, optional
                        Type of grid, i.e. distribution of the values.
                        Default: ``'linear'``.
                        Options are:

                            ``'linear'``
                                Equidistant values in the ``'range'``.
                            ``'logarithmic'``
                                Values in the ``'range'`` that are equidistant
                                in the log scale.
                    ``'log_base'`` : int, optional
                        Log-base that is used if ``'type'`` is
                        ``'logarithmic'``. Default: ``10.``.

            ``'hyperopt_options'`` : dict
                Option dict for ``'hyperopt'`` method with the fields:

                    ``'space'`` : hyperopt space, optional
                        Custom hyperopt search space. If this field is set,
                        ``'range'`` and ``'type'`` are ignored.
                    ``'type'`` : {'uniform'}, optional
                        Type of the space for sampling. Default:
                        ``'uniform'``.
                        Options are:

                            ``'uniform'``
                                Uniform distribution over the ``'range'``.
    """

    def __init__(self, reco_space=None, observation_space=None, name='',
                 hyper_params=None):
        """
        Parameters
        ----------
        reco_space : :class:`odl.discr.DiscretizedSpace`, optional
            Reconstruction space.
        observation_space : :class:`odl.discr.DiscretizedSpace`, optional
            Observation space.
        name : str, optional
            Name of the reconstructor. If empty (the default), the class name
            is used.
        hyper_params : dict, optional
            Hyper parameter values overriding the defaults given by
            :attr:`HYPER_PARAMS`.
        """
        self.reco_space = reco_space
        self.observation_space = observation_space
        self.name = name or self.__class__.__name__
        self.hyper_params = {k: v['default']
                             for k, v in self.HYPER_PARAMS.items()}
        if hyper_params is not None:
            self.hyper_params.update(hyper_params)

    def reconstruct(self, observation, out=None):
        """Reconstruct input data from observation data.

        The default implementation calls `_reconstruct`, automatically
        choosing in-place or out-of-place evaluation depending on the
        signature of `_reconstruct`.

        Parameters
        ----------
        observation : :attr:`observation_space` element-like
            The observation data.
        out : :attr:`reco_space` element-like, optional
            Array to which the result is written (in-place evaluation).
            If `None`, a new array is created (out-of-place evaluation).
            If `None` and `_reconstruct` only supports in-place evaluation
            (i.e. its `out` parameter has no default), the new array is
            ``self.reco_space.zero()``, which is then passed to
            `_reconstruct`.

        Returns
        -------
        reconstruction : :attr:`reco_space` element or `out`
            The reconstruction.
        """
        parameters = signature(self._reconstruct).parameters
        if 'out' in parameters:
            if out is not None:
                self._reconstruct(observation, out)
                reco = out
            elif parameters['out'].default is Parameter.empty:
                # In-place only: provide a zero-initialized output element.
                reco = self.reco_space.zero()
                self._reconstruct(observation, reco)
            else:
                reco = self._reconstruct(observation)
        else:
            reco = self._reconstruct(observation)
            if out is not None:
                # Out-of-place only: copy the result into the given array.
                out[:] = reco
                reco = out
        return reco

    def _reconstruct(self, observation, *args, **kwargs):
        """Reconstruct input data from observation data.

        This method must have one of the following signatures:

            - ``_reconstruct(self, observation, out)`` (in-place)
            - ``_reconstruct(self, observation)`` (out-of-place)
            - ``_reconstruct(self, observation, out=None)`` (optional
              in-place)

        The parameters and return value are documented in
        :meth:`reconstruct`.
        """
        raise NotImplementedError("'_reconstruct' not implemented by class "
                                  "'{}'. Reconstructor subclasses must "
                                  "implement either 'reconstruct' or "
                                  "'_reconstruct'.".format(type(self)))

    def save_hyper_params(self, path):
        """Save hyper parameters to JSON file.

        See also :meth:`load_hyper_params`.

        Parameters
        ----------
        path : str
            Path of the file in which the hyper parameters should be saved.
            The ending ``'.json'`` is automatically appended if not included
            (any other suffix of `path` is kept). The parent directory is
            created if it does not exist.
        """
        # Only append '.json' when missing; stripping an arbitrary suffix
        # (as ``os.path.splitext`` would) maps e.g. 'lr_0.1' and 'lr_0.2' to
        # the same file 'lr_0.json'.
        path = path if path.endswith('.json') else path + '.json'
        path = os.path.abspath(path)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'w') as f:
            json.dump(self.hyper_params, f, indent=1)

    def load_hyper_params(self, path):
        """Load hyper parameters from JSON file.

        See also :meth:`save_hyper_params`.
        A warning is issued for each loaded key that is not listed in
        :attr:`HYPER_PARAMS`; such values are loaded nonetheless.

        Parameters
        ----------
        path : str
            Path of the file in which the hyper parameters are stored.
            The ending ``'.json'`` is automatically appended if not included
            (any other suffix of `path` is kept).
        """
        path = path if path.endswith('.json') else path + '.json'
        with open(path, 'r') as f:
            hyper_params = json.load(f)
        for k in hyper_params:
            if k not in self.HYPER_PARAMS:
                warn("loading value for unknown hyper parameter '{}'"
                     .format(k))
        self.hyper_params.update(hyper_params)

    def save_params(self, path=None, hyper_params_path=None):
        """Save all parameters to file.

        E.g. for learned reconstructors, both hyper parameters and learned
        parameters should be included.
        The purpose of this method, together with :meth:`load_params`, is to
        define a unified way of saving and loading any kind of reconstructor.
        The default implementation calls :meth:`save_hyper_params`.
        Subclasses must reimplement this method in order to include non-hyper
        parameters.

        Implementations should derive a sensible default for
        `hyper_params_path` from `path`, such that all parameters can be saved
        and loaded by specifying only `path`.
        Recommended patterns are:

            - if non-hyper parameters are stored in a single file and `path`
              specifies it without file ending:
              ``hyper_params_path=path + '_hyper_params.json'``
            - if non-hyper parameters are stored in a directory:
              ``hyper_params_path=os.path.join(path, 'hyper_params.json')``.
            - if there are no non-hyper parameters, this default
              implementation can be used:
              ``hyper_params_path=path + '_hyper_params.json'``

        Parameters
        ----------
        path : str[, optional]
            Path at which all (non-hyper) parameters should be saved.
            This argument is required if the reconstructor has non-hyper
            parameters or `hyper_params_path` is omitted.
            If the reconstructor has non-hyper parameters, the implementation
            may interpret it as a file path or as a directory path for
            multiple files (the dir should be created by this method if it
            does not exist).
            If the implementation expects a file path, it should accept it
            without file ending.
        hyper_params_path : str, optional
            Path of the file in which the hyper parameters should be saved.
            The ending ``'.json'`` is automatically appended if not included.
            If not specified, it should be determined from `path` (see
            method description above). The default implementation saves to
            the file ``path + '_hyper_params.json'``.

        Raises
        ------
        ValueError
            If both `path` and `hyper_params_path` are `None`.
        """
        if hyper_params_path is None:
            if path is None:
                raise ValueError(
                    'either `path` or `hyper_params_path` required (in '
                    'default implementation of `Reconstructor.save_params`)')
            hp_path = path + '_hyper_params.json'
        else:
            hp_path = (hyper_params_path
                       if hyper_params_path.endswith('.json')
                       else hyper_params_path + '.json')
        self.save_hyper_params(hp_path)

    def load_params(self, path=None, hyper_params_path=None):
        """Load all parameters from file.

        E.g. for learned reconstructors, both hyper parameters and learned
        parameters should be included.
        The purpose of this method, together with :meth:`save_params`, is to
        define a unified way of saving and loading any kind of reconstructor.
        The default implementation calls :meth:`load_hyper_params`.
        Subclasses must reimplement this method in order to include non-hyper
        parameters.

        See :meth:`save_params` for recommended patterns to derive a default
        `hyper_params_path` from `path`.

        Parameters
        ----------
        path : str[, optional]
            Path at which all (non-hyper) parameters are stored.
            This argument is required if the reconstructor has non-hyper
            parameters or `hyper_params_path` is omitted.
            If the reconstructor has non-hyper parameters, the implementation
            may interpret it as a file path or as a directory path for
            multiple files.
            If the implementation expects a file path, it should accept it
            without file ending.
        hyper_params_path : str, optional
            Path of the file in which the hyper parameters are stored.
            The ending ``'.json'`` is automatically appended if not included.
            If not specified, it should be determined from `path` (see
            description of :meth:`save_params`). The default implementation
            reads from the file ``path + '_hyper_params.json'``.

        Raises
        ------
        ValueError
            If both `path` and `hyper_params_path` are `None`.
        """
        if hyper_params_path is None:
            if path is None:
                raise ValueError(
                    'either `path` or `hyper_params_path` required (in '
                    'default implementation of `Reconstructor.load_params`)')
            hp_path = path + '_hyper_params.json'
        else:
            hp_path = (hyper_params_path
                       if hyper_params_path.endswith('.json')
                       else hyper_params_path + '.json')
        self.load_hyper_params(hp_path)
class LearnedReconstructor(Reconstructor):
    """Base class for reconstructors with trainable (learned) parameters.

    Subclasses implement :meth:`train`. To support saving and loading, they
    also implement :meth:`save_learned_params` and
    :meth:`load_learned_params`, which are called by :meth:`save_params` and
    :meth:`load_params`, respectively.
    """

    def train(self, dataset):
        """Train the reconstructor with a dataset by adapting its parameters.

        Only the training and validation data from `dataset` should be
        used.

        Parameters
        ----------
        dataset : :class:`.Dataset`
            The dataset whose training data is used.
        """
        raise NotImplementedError

    def save_params(self, path, hyper_params_path=None):
        """Save all parameters to file.

        First calls :meth:`save_hyper_params`, then
        :meth:`save_learned_params`. The latter is expected to be implemented
        by the subclass.

        `path` is assumed to be a single file name, preferably given without
        file ending. Subclasses that interpret `path` as a directory should
        reimplement this method to follow the recommended default pattern
        ``hyper_params_path=os.path.join(path, 'hyper_params.json')``.

        Parameters
        ----------
        path : str
            Path at which the learned parameters should be saved; forwarded
            to :meth:`save_learned_params`. For file paths, leaving out the
            file ending is preferred (otherwise the default value of
            `hyper_params_path` is suboptimal).
        hyper_params_path : str, optional
            File for the hyper parameters. ``'.json'`` is appended if not
            already present. Defaults to ``path + '_hyper_params.json'``.
        """
        if hyper_params_path is None:
            hp_path = path + '_hyper_params.json'
        elif hyper_params_path.endswith('.json'):
            hp_path = hyper_params_path
        else:
            hp_path = hyper_params_path + '.json'
        self.save_hyper_params(hp_path)
        self.save_learned_params(path)

    def load_params(self, path, hyper_params_path=None):
        """Load all parameters from file.

        First calls :meth:`load_hyper_params`, then
        :meth:`load_learned_params`. The latter is expected to be implemented
        by the subclass.

        `path` is assumed to be a single file name, preferably given without
        file ending. Subclasses that interpret `path` as a directory should
        reimplement this method to follow the recommended default pattern
        ``hyper_params_path=os.path.join(path, 'hyper_params.json')``.

        Parameters
        ----------
        path : str
            Path at which the parameters are stored; forwarded to
            :meth:`load_learned_params`. For file paths, leaving out the file
            ending is preferred (otherwise the default value of
            `hyper_params_path` is suboptimal).
        hyper_params_path : str, optional
            File containing the hyper parameters. ``'.json'`` is appended if
            not already present. Defaults to ``path + '_hyper_params.json'``.
        """
        if hyper_params_path is None:
            hp_path = path + '_hyper_params.json'
        elif hyper_params_path.endswith('.json'):
            hp_path = hyper_params_path
        else:
            hp_path = hyper_params_path + '.json'
        self.load_hyper_params(hp_path)
        self.load_learned_params(path)

    def save_learned_params(self, path):
        """Save learned parameters to file.

        Parameters
        ----------
        path : str
            Location for the learned parameters. Implementations may treat it
            as a file path or as a directory path for multiple files (to be
            created if it does not exist). File-path implementations should
            accept it without file ending.
        """
        raise NotImplementedError

    def load_learned_params(self, path):
        """Load learned parameters from file.

        Parameters
        ----------
        path : str
            Location of the learned parameters. Implementations may treat it
            as a file path or as a directory path for multiple files.
            File-path implementations should accept it without file ending.
        """
        raise NotImplementedError
class IterativeReconstructor(Reconstructor):
    """Iterative reconstructor base class.

    It is recommended to use :class:`StandardIterativeReconstructor` as a base
    class for iterative reconstructors if suitable, which provides some default
    implementation.

    Subclasses must call :attr:`callback` after each iteration in
    ``self.reconstruct``. This is e.g. required by the :mod:`~dival.evaluation`
    module.

    Attributes
    ----------
    callback : ``odl.solvers.util.callback.Callback`` or `None`
        Callback to be called after each iteration.
    """
    HYPER_PARAMS = deepcopy(Reconstructor.HYPER_PARAMS)
    HYPER_PARAMS.update({
        'iterations': {
            'default': 100,
            'retrain': False
        }
    })

    def __init__(self, callback=None, **kwargs):
        """
        Parameters
        ----------
        callback : ``odl.solvers.util.callback.Callback``, optional
            Callback to be called after each iteration.
        **kwargs
            Forwarded to :meth:`Reconstructor.__init__`.
        """
        self.callback = callback
        super().__init__(**kwargs)

    def reconstruct(self, observation, out=None, callback=None):
        """Reconstruct input data from observation data.

        Same as :meth:`Reconstructor.reconstruct`, but with additional
        optional `callback` parameter.

        Parameters
        ----------
        observation : :attr:`observation_space` element-like
            The observation data.
        out : :attr:`reco_space` element-like, optional
            Array to which the result is written (in-place evaluation).
            If `None`, a new array is created (out-of-place evaluation).
        callback : ``odl.solvers.util.callback.Callback``, optional
            Additional callback for this reconstruction that is temporarily
            composed with :attr:`callback`, i.e. also called after each
            iteration. If `None`, just :attr:`callback` is called.
            :attr:`callback` is restored afterwards, also if the
            reconstruction raises an exception.

        Returns
        -------
        reconstruction : :attr:`reco_space` element or `out`
            The reconstruction.
        """
        if callback is None:
            return super().reconstruct(observation, out=out)
        orig_callback = self.callback
        self.callback = (callback if orig_callback is None
                         else orig_callback & callback)
        try:
            return super().reconstruct(observation, out=out)
        finally:
            # Restore even on error (e.g. KeyboardInterrupt mid-iteration),
            # so the temporary callback does not leak into later calls.
            self.callback = orig_callback
class StandardIterativeReconstructor(IterativeReconstructor):
    """Standard iterative reconstructor base class.

    Provides a default implementation that only requires subclasses to
    implement :meth:`_compute_iterate` and optionally :meth:`_setup`.

    Attributes
    ----------
    x0 : :attr:`reco_space` element-like or `None`
        Default initial value for the iterative reconstruction.
        Can be overridden by passing a different ``x0`` to :meth:`reconstruct`.
    callback : ``odl.solvers.util.callback.Callback`` or `None`
        Callback that is called after each iteration.
    """

    def __init__(self, x0=None, callback=None, **kwargs):
        """
        Parameters
        ----------
        x0 : :attr:`reco_space` element-like, optional
            Default initial value for the iterative reconstruction.
            Can be overridden by passing a different ``x0`` to
            :meth:`reconstruct`.
        callback : ``odl.solvers.util.callback.Callback``, optional
            Callback that is called after each iteration.
        **kwargs
            Forwarded to :meth:`IterativeReconstructor.__init__`.
        """
        self.x0 = x0
        # Per-call state, set by `reconstruct` and read by `_reconstruct`.
        # Initialized here so `_reconstruct` never hits an AttributeError
        # when reached without going through `reconstruct`.
        self._x0_override = None
        self._last_iter = 0
        super().__init__(callback=callback, **kwargs)

    def _setup(self, observation):
        """Setup before iteration process.

        Called by the default implementation of :meth:`_reconstruct` in the
        beginning, i.e. before computing the first iterate.

        Parameters
        ----------
        observation : :attr:`observation_space` element-like
            The observation data (forwarded from :meth:`reconstruct`).
        """
        pass

    def _compute_iterate(self, observation, reco_previous, out):
        """Compute next iterate.

        This method implements the iteration step in the default
        implementation of :meth:`_reconstruct`.

        Parameters
        ----------
        observation : :attr:`observation_space` element-like
            The observation data (forwarded from :meth:`reconstruct`).
        reco_previous : :attr:`reco_space` element-like
            The previous iterate value.
        out : :attr:`reco_space` element-like
            Array to which the iterate value is written.
        """
        raise NotImplementedError

    def reconstruct(self, observation, out=None, x0=None, last_iter=0,
                    callback=None):
        """Reconstruct input data from observation data.

        Same as :meth:`Reconstructor.reconstruct`, but with additional options
        for iterative reconstructors.

        Parameters
        ----------
        observation : :attr:`observation_space` element-like
            The observation data.
        out : :attr:`reco_space` element-like, optional
            Array to which the result is written (in-place evaluation).
            If `None`, a new array is created (out-of-place evaluation).
        x0 : :attr:`reco_space` element-like, optional
            Initial value for the iterative reconstruction.
            Overrides the attribute :attr:`x0`, which can be set when calling
            :meth:`__init__`.
            If both :attr:`x0` and this argument are `None`, the default
            implementation uses the value of `out` if called in-place, or zero
            if called out-of-place.
        last_iter : int, optional
            If `x0` is the result of an iteration by this method, this can be
            used to specify the number of iterations so far.
            The number of iterations for the current call is
            ``self.hyper_params['iterations'] - last_iter``.
        callback : ``odl.solvers.util.callback.Callback``, optional
            Additional callback for this reconstruction that is temporarily
            composed with :attr:`callback`, i.e. also called after each
            iteration. If `None`, just :attr:`callback` is called.

        Returns
        -------
        reconstruction : :attr:`reco_space` element or `out`
            The reconstruction.
        """
        self._x0_override = x0
        self._last_iter = last_iter
        try:
            return super().reconstruct(observation, out=out,
                                       callback=callback)
        finally:
            # The overrides apply to this call only; reset them so a
            # (possibly large) x0 array is not retained afterwards.
            self._x0_override = None
            self._last_iter = 0

    def _reconstruct(self, observation, out):
        """Run the iteration, writing the iterates into `out` in place.

        Calls :meth:`_setup`, then initializes `out` from the per-call `x0`
        override or :attr:`x0` (if neither is set, `out` keeps its value).
        Then it runs ``self.hyper_params['iterations'] - last_iter`` steps of
        :meth:`_compute_iterate`, calling :attr:`callback` after each step.
        """
        self._setup(observation)
        x = out
        if self._x0_override is not None:
            x[:] = self._x0_override  # override for specific reconstruction
        elif self.x0 is not None:
            x[:] = self.x0  # default init value
        # keep value of `out` if no `x0` was specified
        for _ in range(self.hyper_params['iterations'] - self._last_iter):
            self._compute_iterate(observation, reco_previous=x.copy(), out=x)
            if self.callback is not None:
                self.callback(x)
class FunctionReconstructor(Reconstructor):
    """Reconstructor that delegates to an arbitrary callable.

    The reconstruction is computed as
    ``function(observation, *fun_args, **fun_kwargs)``; only out-of-place
    evaluation is implemented by :meth:`_reconstruct`.

    Attributes
    ----------
    function : callable
        Callable that is used in `reconstruct`.
    fun_args : list
        Extra positional arguments passed to `function` after the
        observation.
    fun_kwargs : dict
        Keyword arguments passed to `function`.
    """

    def __init__(self, function, *args, fun_args=None, fun_kwargs=None,
                 **kwargs):
        """
        Parameters
        ----------
        function : callable
            Callable that is used in :meth:`reconstruct`.
        fun_args : list, optional
            Extra positional arguments passed to `function`. Defaults to an
            empty list.
        fun_kwargs : dict, optional
            Keyword arguments passed to `function`. Defaults to an empty dict.
        *args, **kwargs
            Forwarded to :meth:`Reconstructor.__init__`.
        """
        super().__init__(*args, **kwargs)
        self.function = function
        self.fun_args = fun_args if fun_args else []
        self.fun_kwargs = fun_kwargs if fun_kwargs else {}

    def _reconstruct(self, observation):
        func = self.function
        extra_args = self.fun_args
        extra_kwargs = self.fun_kwargs
        return func(observation, *extra_args, **extra_kwargs)