dival.reconstructors.standard_learned_reconstructor module
- class dival.reconstructors.standard_learned_reconstructor.StandardLearnedReconstructor(op, hyper_params=None, num_data_loader_workers=8, use_cuda=True, show_pbar=True, log_dir=None, log_num_validation_samples=0, save_best_learned_params_path=None, torch_manual_seed=1, shuffle='auto', worker_init_fn=None, **kwargs)
Bases: LearnedReconstructor
Standard learned reconstructor base class.
Provides a default implementation that only requires subclasses to implement init_model(). By default, the Adam optimizer is used; this can be changed by reimplementing init_optimizer(). Also, a OneCycleLR scheduler is used by default, which can be changed by reimplementing init_scheduler(). The training implementation selects the best model reached after an integer number of epochs, based on the validation set.
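The selection of the best model can be pictured with a plain-Python sketch. It is not dival's training loop: the hypothetical helper train_keep_best treats the model state as a dict, and the callables train_one_epoch and validation_loss are assumed stand-ins for one pass over the training data and an evaluation on the validation set. The validation loss is only checked at the end of each full epoch, which is what "after an integer number of epochs" refers to.

```python
import copy
from typing import Any, Callable, Dict, List, Tuple


def train_keep_best(
    state: Dict[str, Any],
    train_one_epoch: Callable[[Dict[str, Any]], None],
    validation_loss: Callable[[Dict[str, Any]], float],
    epochs: int,
) -> Tuple[Dict[str, Any], int, List[float]]:
    """Train for `epochs` epochs and return the snapshot with the lowest
    validation loss, its epoch index and the per-epoch loss history."""
    best_state = copy.deepcopy(state)
    best_epoch = -1
    best_loss = float('inf')
    history = []
    for epoch in range(epochs):
        train_one_epoch(state)          # updates `state` in place
        loss = validation_loss(state)   # evaluated once, at the end of the epoch
        history.append(loss)
        if loss < best_loss:
            best_loss = loss
            best_epoch = epoch
            best_state = copy.deepcopy(state)  # snapshot, not a reference
    return best_state, best_epoch, history
```

Deep-copying the snapshot matters: keeping a reference to state instead would let later epochs overwrite the "best" parameters.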
The hyper parameter 'normalize_by_opnorm' selects whether op should be normalized by the operator norm. In this case, the inputs to model are divided by the operator norm.
- model
The neural network. Must be initialized by the subclass init_model() implementation.
- Type: torch.nn.Module or None
- non_normed_op
The original op passed to __init__(), regardless of self.hyper_params['normalize_by_opnorm']. See also op.
- Type: odl.operator.Operator
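To illustrate what 'normalize_by_opnorm' amounts to, here is a minimal sketch that uses a plain Python matrix as a stand-in for the ODL operator (an assumption; dival itself works with odl.operator.Operator objects). It estimates the operator norm, i.e. the largest singular value, by power iteration on A^T A; ODL provides power_method_opnorm for the same purpose. Dividing the observations y = A x by this norm is consistent with using A / ||A|| as the forward operator, since y / ||A|| = (A / ||A||) x.

```python
import math
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(a: Matrix, x: Vector) -> Vector:
    """Compute A x."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in a]


def rmatvec(a: Matrix, y: Vector) -> Vector:
    """Compute A^T y (the adjoint of a real matrix applied to y)."""
    return [sum(a[i][j] * y[i] for i in range(len(a))) for j in range(len(a[0]))]


def norm(x: Vector) -> float:
    return math.sqrt(sum(v * v for v in x))


def opnorm_power_method(a: Matrix, iterations: int = 100, seed: int = 0) -> float:
    """Estimate ||A||, the largest singular value, by power iteration on A^T A."""
    rng = random.Random(seed)
    x = [rng.random() + 0.1 for _ in range(len(a[0]))]  # random positive start
    for _ in range(iterations):
        x_norm = norm(x)
        x = [v / x_norm for v in x]    # keep the iterate at unit length
        x = rmatvec(a, matvec(a, x))   # x <- A^T A x
    x_norm = norm(x)
    x = [v / x_norm for v in x]
    return norm(matvec(a, x))          # ||A x|| for a unit vector x


def normalize_operator(a: Matrix) -> Matrix:
    """Divide the operator by its estimated norm, so the result has norm ~1."""
    opnorm = opnorm_power_method(a)
    return [[a_ij / opnorm for a_ij in row] for row in a]
```

For A = [[1, 1], [0, 1]] the estimate converges to the golden ratio (about 1.618), the largest singular value of that matrix.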
- HYPER_PARAMS = {'batch_size': {'default': 64, 'retrain': True}, 'epochs': {'default': 20, 'retrain': True}, 'lr': {'default': 0.01, 'retrain': True}, 'normalize_by_opnorm': {'default': False, 'retrain': True}}
Specification of hyper parameters.
This class attribute is a dict that lists the hyper parameters of the reconstructor. It should not be hidden by an instance attribute of the same name (i.e. by assigning a value to self.HYPER_PARAMS in an instance of a subtype).
Note: in order to inherit HYPER_PARAMS from a super class, the subclass should create a deep copy of it, i.e. execute HYPER_PARAMS = copy.deepcopy(SuperReconstructorClass.HYPER_PARAMS) in the class body.
The keys of this dict are the names of the hyper parameters, and each value is a dict with the following fields.
Standard fields:
'default' – Default value.
'retrain' (bool, optional) – Whether training depends on the parameter. Default: False. Any custom subclass of LearnedReconstructor must set this field to True if training depends on the parameter value.
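The deep-copy note and the standard fields can be combined in a small sketch. BaseReconstructor and MyReconstructor are hypothetical stand-ins (not dival classes); the base dict mirrors the HYPER_PARAMS listed above, and retrain_params is a hypothetical helper that applies the documented default of False for a missing 'retrain' field.

```python
import copy


class BaseReconstructor:
    """Hypothetical stand-in for a reconstructor class with HYPER_PARAMS."""
    HYPER_PARAMS = {
        'batch_size': {'default': 64, 'retrain': True},
        'epochs': {'default': 20, 'retrain': True},
        'lr': {'default': 0.01, 'retrain': True},
        'normalize_by_opnorm': {'default': False, 'retrain': True},
    }


class MyReconstructor(BaseReconstructor):
    # Deep copy, so that modifying the nested dicts below does not
    # change BaseReconstructor.HYPER_PARAMS.
    HYPER_PARAMS = copy.deepcopy(BaseReconstructor.HYPER_PARAMS)
    HYPER_PARAMS['lr']['default'] = 0.001                          # override a default
    HYPER_PARAMS['channels'] = {'default': 32, 'retrain': True}    # affects training
    HYPER_PARAMS['verbose'] = {'default': False}                   # 'retrain' omitted


def retrain_params(hyper_params_spec):
    """Names of hyper parameters whose change requires retraining."""
    return [name for name, field in hyper_params_spec.items()
            if field.get('retrain', False)]
```

With a shallow copy such as dict(BaseReconstructor.HYPER_PARAMS), the assignment to HYPER_PARAMS['lr']['default'] would also change the base class, because the nested dicts would be shared.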
Hyper parameter search fields:
'range' ((float, float), optional) – Interval of valid values. If this field is set, the parameter is taken to be real-valued. Either 'range' or 'choices' has to be set.
'choices' (sequence, optional) – Sequence of valid values of any type. If this field is set, 'range' is ignored. Can be used to perform manual grid search. Either 'range' or 'choices' has to be set.
'method' ({'grid_search', 'hyperopt'}, optional) – Optimization method for the parameter. Default: 'grid_search'. Options are:
'grid_search': Grid search over a sequence of fixed values. Can be configured by the dict 'grid_search_options'.
'hyperopt': Random search using the hyperopt package. Can be configured by the dict 'hyperopt_options'.
'grid_search_options' (dict) – Option dict for grid search. The following fields determine how 'range' is sampled (in case it is specified and no 'choices' are specified):
'num_samples' (int, optional) – Number of values. Default: 10.
'type' ({'linear', 'logarithmic'}, optional) – Type of grid, i.e. distribution of the values. Default: 'linear'. Options are:
'linear': Equidistant values in the 'range'.
'logarithmic': Values in the 'range' that are equidistant in the log scale.
'log_base' (int, optional) – Log-base that is used if 'type' is 'logarithmic'. Default: 10.
'hyperopt_options' (dict) – Option dict for the 'hyperopt' method, with the fields:
'space' (hyperopt space, optional) – Custom hyperopt search space. If this field is set, 'range' and 'type' are ignored.
'type' ({'uniform'}, optional) – Type of the space for sampling. Default: 'uniform'. Options are:
'uniform': Uniform distribution over the 'range'.
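The search fields above can be read as a recipe for the candidate values of one hyper parameter. The sketch below is a hypothetical helper, not dival's search code. It assumes that grids include both endpoints of 'range' (as numpy.linspace does), it replaces hyperopt's uniform sampling by the standard library's random.uniform, and its num_draws argument is an invented stand-in for however many random evaluations a search would run; a custom 'space' is not handled.

```python
import math
import random
from typing import Any, Dict, List


def candidate_values(spec: Dict[str, Any], num_draws: int = 10,
                     seed: int = 0) -> List[Any]:
    """Hypothetical helper: values a search would try for one HYPER_PARAMS entry."""
    if 'choices' in spec:                  # 'choices' takes precedence over 'range'
        return list(spec['choices'])
    if 'range' not in spec:
        raise ValueError("either 'range' or 'choices' has to be set")
    lo, hi = spec['range']
    method = spec.get('method', 'grid_search')
    if method == 'grid_search':
        opts = spec.get('grid_search_options', {})
        n = opts.get('num_samples', 10)
        steps = max(n - 1, 1)              # endpoints included (assumption)
        grid_type = opts.get('type', 'linear')
        if grid_type == 'linear':
            return [lo + (hi - lo) * i / steps for i in range(n)]
        if grid_type == 'logarithmic':
            base = opts.get('log_base', 10)
            a, b = math.log(lo, base), math.log(hi, base)
            return [base ** (a + (b - a) * i / steps) for i in range(n)]
        raise ValueError(f"unknown grid type {grid_type!r}")
    if method == 'hyperopt':
        opts = spec.get('hyperopt_options', {})
        if 'space' in opts:
            raise NotImplementedError('custom hyperopt spaces are not covered here')
        # Stand-in for hyperopt's uniform sampling over 'range'.
        rng = random.Random(seed)
        return [rng.uniform(lo, hi) for _ in range(num_draws)]
    raise ValueError(f"unknown method {method!r}")
```

For example, a learning rate is often searched on a logarithmic grid, e.g. {'range': (1e-4, 1e-1), 'grid_search_options': {'type': 'logarithmic', 'num_samples': 4}}, which yields roughly 1e-4, 1e-3, 1e-2 and 1e-1.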
- __init__(op, hyper_params=None, num_data_loader_workers=8, use_cuda=True, show_pbar=True, log_dir=None, log_num_validation_samples=0, save_best_learned_params_path=None, torch_manual_seed=1, shuffle='auto', worker_init_fn=None, **kwargs)
- Parameters:
op (odl.operator.Operator) – Forward operator.
num_data_loader_workers (int, optional) – Number of parallel workers to use for loading data.
use_cuda (bool, optional) – Whether to use CUDA for the model.
show_pbar (bool, optional) – Whether to show tqdm progress bars during the epochs.
log_dir (str, optional) – TensorBoard log directory (name of sub-directory in utils/logs). If None, no logs are written.
log_num_validation_samples (int, optional) – Number of validation images to store in TensorBoard logs. This option only takes effect if log_dir is not None.
save_best_learned_params_path (str, optional) – Save best model weights during training under the specified path by calling save_learned_params().
torch_manual_seed (int, optional) – Fixed seed to set by torch.manual_seed before training. The default is 1. It can be set to None or False to disable the manual seed.
shuffle ({'auto', False, True}, optional) – Whether to use shuffling when loading data. When 'auto' is specified (the default), True is used if and only if the dataset passed to train() supports random access.
worker_init_fn (callable, optional) – Callable worker_init_fn passed to torch.utils.data.DataLoader.__init__(), which can be used to configure the dataset copies for different worker instances (cf. torch's IterableDataset docs).
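The 'auto' option for shuffle can be expressed as a small decision function. This is a hypothetical helper for illustration; it takes the random-access capability of the dataset as a plain bool rather than querying a dival dataset.

```python
from typing import Union


def resolve_shuffle(shuffle: Union[str, bool], supports_random_access: bool) -> bool:
    """Map the `shuffle` argument to a plain bool (hypothetical helper)."""
    if shuffle == 'auto':
        # Shuffle only if the dataset supports random access.
        return supports_random_access
    if isinstance(shuffle, bool):
        return shuffle          # an explicit choice always wins
    raise ValueError(f"shuffle must be 'auto', True or False, got {shuffle!r}")
```

The resulting bool is the kind of value one would pass on as the shuffle argument of torch.utils.data.DataLoader.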
- property opnorm
- property op
odl.operator.Operator: The forward operator, normalized if self.hyper_params['normalize_by_opnorm'] is True.
- train(dataset)
Train the reconstructor with a dataset by adapting its parameters.
Should only use the training and validation data from dataset.
- Parameters:
dataset (Dataset) – The dataset from which the training data should be used.
- init_transform(dataset)
Initialize the transform (_transform) that is applied on each training sample, e.g. for data augmentation. In the default implementation of train(), it is passed to Dataset.create_torch_dataset() when creating the training (but not the validation) torch dataset, which applies the transform to the (tuple of) torch tensor(s) right before returning, i.e. after reshaping to (1,) + orig_shape.
The default implementation of this method disables the transform by assigning None. Called in train() at the beginning, i.e. before calling init_model(), init_optimizer() and init_scheduler().
- Parameters:
dataset (dival.datasets.dataset.Dataset) – The dival dataset passed to train().
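A transform is simply a callable that receives the (tuple of) sample(s) and returns the transformed tuple. The sketch below is a hypothetical augmentation that perturbs the observation with Gaussian noise and leaves the ground truth untouched; nested lists of shape (1, n) stand in for the torch tensors, so this is an illustration of the interface rather than code that operates on dival's tensors. In a subclass, an instance of such a class would be assigned to self._transform in an overridden init_transform().

```python
import random
from typing import List, Tuple

Sample = Tuple[List[List[float]], List[List[float]]]  # (observation, ground truth)


class AddObservationNoise:
    """Hypothetical augmentation: add Gaussian noise to the observation only."""

    def __init__(self, stddev: float, seed: int = 0):
        self.stddev = stddev
        self.rng = random.Random(seed)

    def __call__(self, sample: Sample) -> Sample:
        obs, gt = sample
        noisy_obs = [[v + self.rng.gauss(0.0, self.stddev) for v in row]
                     for row in obs]
        return noisy_obs, gt    # ground truth is returned unchanged
```

Because the transform is only used for the training dataset, validation samples stay noise-free, so the best-epoch selection is not affected by the augmentation.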
- property transform
callable: Transform that is applied on each sample, usually set by init_transform(), which gets called in train().
- init_model()
Initialize model. Called in train() after calling init_transform(), but before calling init_optimizer() and init_scheduler().
- init_optimizer(dataset_train)
Initialize the optimizer. Called in train(), after calling init_transform() and init_model(), but before calling init_scheduler().
- Parameters:
dataset_train (torch.utils.data.Dataset) – The training (torch) dataset constructed in train().
- property optimizer
torch.optim.Optimizer: The optimizer, usually set by init_optimizer(), which gets called in train().
- init_scheduler(dataset_train)
Initialize the learning rate scheduler. Called in train(), after calling init_transform(), init_model() and init_optimizer().
- Parameters:
dataset_train (torch.utils.data.Dataset) – The training (torch) dataset constructed in train().
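The documented order of the initialization hooks in train() can be summarized by a skeleton that only records the calls. TrainingSkeleton is a hypothetical class, and list(dataset) is an assumed stand-in for constructing the training torch dataset; the actual epoch loop is omitted.

```python
class TrainingSkeleton:
    """Hypothetical illustration of the hook order documented for train()."""

    def __init__(self):
        self.calls = []

    def init_transform(self, dataset):
        self.calls.append('init_transform')
        self._transform = None              # default: no augmentation

    def init_model(self):
        self.calls.append('init_model')

    def init_optimizer(self, dataset_train):
        self.calls.append('init_optimizer')

    def init_scheduler(self, dataset_train):
        self.calls.append('init_scheduler')

    def train(self, dataset):
        self.init_transform(dataset)        # first, so the transform can be used below
        dataset_train = list(dataset)       # stand-in for the training torch dataset
        self.init_model()
        self.init_optimizer(dataset_train)  # needs the model's parameters
        self.init_scheduler(dataset_train)  # needs the optimizer
        # ... the epoch loop would follow here
```

Overriding one hook in a subclass leaves the others and their order untouched, which is why only init_model() is mandatory.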
- property scheduler
torch learning rate scheduler: The scheduler, usually set by init_scheduler(), which gets called in train().
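PyTorch's OneCycleLR needs the total number of optimizer steps (or epochs together with steps per epoch), which can be derived from len(dataset_train), the batch size and the number of epochs. The following is a simplified sketch of the one-cycle shape with cosine annealing, using PyTorch's default pct_start=0.3, div_factor=25 and final_div_factor=1e4; PyTorch's exact phase boundaries differ slightly, and how dival computes the step count is an assumption here.

```python
import math


def steps_per_epoch(num_train_samples: int, batch_size: int) -> int:
    """Optimizer steps per epoch, counting a final partial batch."""
    return math.ceil(num_train_samples / batch_size)


def one_cycle_lr(step: int, total_steps: int, max_lr: float,
                 pct_start: float = 0.3, div_factor: float = 25.0,
                 final_div_factor: float = 1e4) -> float:
    """Simplified one-cycle schedule: warm up from max_lr / div_factor to
    max_lr, then anneal to (max_lr / div_factor) / final_div_factor."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_steps = pct_start * total_steps
    if step < warmup_steps:
        start, end, pct = initial_lr, max_lr, step / warmup_steps
    else:
        start, end = max_lr, min_lr
        pct = (step - warmup_steps) / (total_steps - warmup_steps)
    # Cosine interpolation from `start` (pct = 0) to `end` (pct = 1).
    return end + (start - end) * (1 + math.cos(math.pi * pct)) / 2
```

With the default hyper parameters (lr 0.01, 20 epochs, batch size 64) and 1000 training samples, this gives 16 steps per epoch and 320 steps in total, with the learning rate peaking at 0.01 about 30 % into training.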
- property batch_size
- property epochs
- property lr
- property normalize_by_opnorm
- save_learned_params(path)
Save learned parameters to file.
- Parameters:
path (str) – Path at which the learned parameters should be saved. Implementations may interpret this as a file path or as a directory path for multiple files (which then should be created if it does not exist). If the implementation expects a file path, it should accept it without file ending.
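The path convention can be sketched with two hypothetical helpers, one for implementations that store a single file and one for implementations that store multiple files in a directory. The '.pt' ending is an assumption (a common convention for files written by torch.save), not something stated above.

```python
import os


def learned_params_file(path: str, ending: str = '.pt') -> str:
    """Accept a file path with or without file ending (hypothetical helper)."""
    return path if path.endswith(ending) else path + ending


def learned_params_dir(path: str) -> str:
    """Use `path` as a directory for multiple files, creating it if needed."""
    os.makedirs(path, exist_ok=True)
    return path
```

Normalizing the ending in one place lets save_learned_params() and load_learned_params() accept the same user-supplied path.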
- load_learned_params(path, convert_data_parallel='auto')
Load learned parameters from file.
- Parameters:
path (str) – Path at which the learned parameters are stored. Implementations may interpret this as a file path or as a directory path for multiple files. If the implementation expects a file path, it should accept it without file ending.
convert_data_parallel (bool or {'auto', 'keep'}, optional) – Whether to automatically convert the model weight names if model is an nn.DataParallel-model but the stored state dict stems from a non-data-parallel model, or vice versa.
'auto' or True: Auto-convert weight names, depending on the type of model.
'keep' or False: Do not convert weight names. Convert to plain weight names.
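The conversion rests on the fact that nn.DataParallel wraps the model as its module attribute, so the keys of its state dict carry a "module." prefix. The sketch below illustrates the 'auto' behaviour on a plain dict; the helper names are hypothetical, and the values would be tensors in a real state dict.

```python
from typing import Dict, TypeVar

V = TypeVar('V')
PREFIX = 'module.'  # prefix of state dict keys of an nn.DataParallel-wrapped model


def to_data_parallel_keys(state_dict: Dict[str, V]) -> Dict[str, V]:
    """Add the prefix to every key that does not have it yet."""
    return {(k if k.startswith(PREFIX) else PREFIX + k): v
            for k, v in state_dict.items()}


def to_plain_keys(state_dict: Dict[str, V]) -> Dict[str, V]:
    """Strip the prefix from every key that has it."""
    return {(k[len(PREFIX):] if k.startswith(PREFIX) else k): v
            for k, v in state_dict.items()}


def convert_keys(state_dict: Dict[str, V], model_is_data_parallel: bool) -> Dict[str, V]:
    """Sketch of the 'auto' option: match the key style to the model type."""
    if model_is_data_parallel:
        return to_data_parallel_keys(state_dict)
    return to_plain_keys(state_dict)
```

Both helpers leave already-matching keys unchanged, so converting a state dict that already has the right style is harmless.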