dival.reconstructors.learnedpd_reconstructor module
- class dival.reconstructors.learnedpd_reconstructor.LearnedPDReconstructor(ray_trafo, **kwargs)[source]
Bases:
StandardLearnedReconstructor
CT reconstructor applying a learned primal-dual iterative scheme ([1]).
References
[1] J. Adler and O. Öktem, "Learned Primal-Dual Reconstruction", IEEE Transactions on Medical Imaging, 37(6), 2018, pp. 1322–1332.
- HYPER_PARAMS = {'batch_norm': {'default': False, 'retrain': True}, 'batch_size': {'default': 5, 'retrain': True}, 'epochs': {'default': 20, 'retrain': True}, 'init_fbp': {'default': False, 'retrain': True}, 'init_filter_type': {'default': 'Hann', 'retrain': True}, 'init_frequency_scaling': {'default': 0.4, 'retrain': True}, 'internal_ch': {'default': 32, 'retrain': True}, 'kernel_size': {'default': 3, 'retrain': True}, 'lr': {'default': 0.001, 'retrain': True}, 'lr_min': {'default': 0.0, 'retrain': True}, 'lrelu_coeff': {'default': 0.2, 'retrain': True}, 'ndual': {'default': 5, 'retrain': True}, 'niter': {'default': 10, 'retrain': True}, 'nlayer': {'default': 3, 'retrain': True}, 'normalize_by_opnorm': {'default': True, 'retrain': True}, 'nprimal': {'default': 5, 'retrain': True}, 'prelu': {'default': True, 'retrain': True}, 'use_sigmoid': {'default': False, 'retrain': True}}
Specification of hyper parameters.
This class attribute is a dict that lists the hyper parameters of the reconstructor. It should not be hidden by an instance attribute of the same name (i.e. by assigning a value to self.HYPER_PARAMS in an instance of a subtype).
Note: in order to inherit HYPER_PARAMS from a super class, the subclass should create a deep copy of it, i.e. execute
HYPER_PARAMS = copy.deepcopy(SuperReconstructorClass.HYPER_PARAMS)
in the class body.
The keys of this dict are the names of the hyper parameters, and each value is a dict with the following fields.
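The deep-copy requirement can be illustrated with a minimal sketch (plain Python, hypothetical class names, not the dival implementation): because the values of HYPER_PARAMS are nested dicts, a shallow copy or a direct reference would let the subclass mutate the super class attribute.

```python
import copy

# Hypothetical base class with a HYPER_PARAMS specification.
class BaseReconstructor:
    HYPER_PARAMS = {
        'lr': {'default': 0.001, 'retrain': True},
    }

class MyReconstructor(BaseReconstructor):
    # Deep-copy so that modifying the subclass spec does not mutate the
    # base class attribute (the nested dicts would otherwise be shared).
    HYPER_PARAMS = copy.deepcopy(BaseReconstructor.HYPER_PARAMS)
    HYPER_PARAMS['lr']['default'] = 0.01  # safe: affects only the copy
    HYPER_PARAMS['epochs'] = {'default': 20, 'retrain': True}

print(BaseReconstructor.HYPER_PARAMS['lr']['default'])  # 0.001, unchanged
print(MyReconstructor.HYPER_PARAMS['lr']['default'])    # 0.01
```

Without copy.deepcopy, the line setting 'default' to 0.01 would also change BaseReconstructor.HYPER_PARAMS, since both classes would share the same inner 'lr' dict.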
Standard fields:

'default'
    Default value.

'retrain'
    bool, optional
    Whether training depends on the parameter. Default: False.
    Any custom subclass of LearnedReconstructor must set this field to True if training depends on the parameter value.
Hyper parameter search fields:

'range'
    (float, float), optional
    Interval of valid values. If this field is set, the parameter is taken to be real-valued. Either 'range' or 'choices' has to be set.

'choices'
    sequence, optional
    Sequence of valid values of any type. If this field is set, 'range' is ignored. Can be used to perform manual grid search. Either 'range' or 'choices' has to be set.

'method'
    {'grid_search', 'hyperopt'}, optional
    Optimization method for the parameter. Default: 'grid_search'. Options are:

    'grid_search'
        Grid search over a sequence of fixed values. Can be configured by the dict 'grid_search_options'.

    'hyperopt'
        Random search using the hyperopt package. Can be configured by the dict 'hyperopt_options'.
'grid_search_options'
    dict
    Option dict for grid search. The following fields determine how 'range' is sampled (in case it is specified and no 'choices' are specified):

    'num_samples'
        int, optional
        Number of values. Default: 10.

    'type'
        {'linear', 'logarithmic'}, optional
        Type of grid, i.e. distribution of the values. Default: 'linear'. Options are:

        'linear'
            Equidistant values in the 'range'.

        'logarithmic'
            Values in the 'range' that are equidistant in log scale.

    'log_base'
        int, optional
        Log base that is used if 'type' is 'logarithmic'. Default: 10.
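The sampling behaviour described by 'grid_search_options' can be sketched as follows. The helper below is hypothetical (it is not part of the dival API); it only illustrates how a 'range' could be turned into a linear or logarithmic grid of 'num_samples' values.

```python
import math

def sample_grid(range_, num_samples=10, type='linear', log_base=10):
    """Hypothetical sketch of sampling 'range' per 'grid_search_options'."""
    lo, hi = range_
    if num_samples == 1:
        return [lo]
    if type == 'linear':
        # Equidistant values in the range.
        step = (hi - lo) / (num_samples - 1)
        return [lo + i * step for i in range(num_samples)]
    elif type == 'logarithmic':
        # Values equidistant in log scale (requires lo, hi > 0).
        log_lo = math.log(lo, log_base)
        log_hi = math.log(hi, log_base)
        step = (log_hi - log_lo) / (num_samples - 1)
        return [log_base ** (log_lo + i * step) for i in range(num_samples)]
    raise ValueError("unknown grid type: {}".format(type))

print(sample_grid((0.0, 1.0), num_samples=5))  # [0.0, 0.25, 0.5, 0.75, 1.0]
print(sample_grid((1e-4, 1.0), num_samples=5, type='logarithmic'))
```

A logarithmic grid is the natural choice for a parameter such as 'lr', where candidate values typically span several orders of magnitude.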
'hyperopt_options'
    dict
    Option dict for the 'hyperopt' method with the fields:

    'space'
        hyperopt space, optional
        Custom hyperopt search space. If this field is set, 'range' and 'type' are ignored.

    'type'
        {'uniform'}, optional
        Type of the space for sampling. Default: 'uniform'. Options are:

        'uniform'
            Uniform distribution over the 'range'.
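Putting the fields together, a HYPER_PARAMS entry might look like the sketch below. The parameter names mirror those in the spec above, but the concrete values and choice lists are illustrative assumptions, not the dival defaults for search.

```python
# Hypothetical HYPER_PARAMS fragment: a real-valued learning rate searched on
# a logarithmic grid, and a discrete parameter searched via 'choices'.
HYPER_PARAMS_EXAMPLE = {
    'lr': {
        'default': 0.001,
        'retrain': True,
        'range': (1e-5, 1e-1),          # real-valued parameter
        'method': 'grid_search',
        'grid_search_options': {
            'num_samples': 5,
            'type': 'logarithmic',
            'log_base': 10,
        },
    },
    'init_filter_type': {
        'default': 'Hann',
        'retrain': True,
        # 'choices' makes the parameter discrete; 'range' would be ignored.
        'choices': ['Ram-Lak', 'Hamming', 'Hann'],  # assumed example values
    },
}
```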
- __init__(ray_trafo, **kwargs)[source]
- Parameters:
  ray_trafo (odl.tomo.RayTransform) – Ray transform (the forward operator). Further keyword arguments are passed to super().__init__().
- init_model()[source]
Initialize model. Called in train() after calling init_transform(), but before calling init_optimizer() and init_scheduler().
- init_optimizer(dataset_train)[source]
Initialize the optimizer. Called in train(), after calling init_transform() and init_model(), but before calling init_scheduler().
- Parameters:
  dataset_train (torch.utils.data.Dataset) – The training (torch) dataset constructed in train().
- init_scheduler(dataset_train)[source]
Initialize the learning rate scheduler. Called in train(), after calling init_transform(), init_model() and init_optimizer().
- Parameters:
  dataset_train (torch.utils.data.Dataset) – The training (torch) dataset constructed in train().
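The call order documented for the three init_* hooks can be summarised with a minimal sketch (plain Python, no torch; the class and its bodies are stand-ins, only the method names and their ordering come from the docs above):

```python
# Hypothetical reconstructor recording the order in which train() invokes
# its initialization hooks: init_transform -> init_model -> init_optimizer
# -> init_scheduler, as documented above.
class SketchReconstructor:
    def __init__(self):
        self.calls = []

    def init_transform(self):
        self.calls.append('init_transform')

    def init_model(self):
        self.calls.append('init_model')

    def init_optimizer(self, dataset_train):
        self.calls.append('init_optimizer')

    def init_scheduler(self, dataset_train):
        self.calls.append('init_scheduler')

    def train(self, dataset_train):
        self.init_transform()
        self.init_model()
        self.init_optimizer(dataset_train)
        self.init_scheduler(dataset_train)
        # ... the actual training loop would run here ...

r = SketchReconstructor()
r.train(dataset_train=None)
print(r.calls)
# ['init_transform', 'init_model', 'init_optimizer', 'init_scheduler']
```

This ordering matters when overriding the hooks: init_optimizer may rely on self.model existing, and init_scheduler may rely on the optimizer.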
- property batch_norm
- property batch_size
- property epochs
- property init_fbp
- property init_filter_type
- property init_frequency_scaling
- property internal_ch
- property kernel_size
- property lr
- property lr_min
- property lrelu_coeff
- property ndual
- property niter
- property nlayer
- property normalize_by_opnorm
- property nprimal
- property prelu
- property use_sigmoid