dival.reconstructors.standard_learned_reconstructor module
-
class dival.reconstructors.standard_learned_reconstructor.StandardLearnedReconstructor(op, hyper_params=None, num_data_loader_workers=8, use_cuda=True, show_pbar=True, log_dir=None, log_num_validation_samples=0, save_best_learned_params_path=None, torch_manual_seed=1, shuffle='auto', worker_init_fn=None, **kwargs)
Bases: dival.reconstructors.reconstructor.LearnedReconstructor
Standard learned reconstructor base class.
Provides a default implementation that only requires subclasses to implement init_model().
By default, the Adam optimizer is used. This can be changed by reimplementing init_optimizer(). Also, a OneCycleLR scheduler is used by default, which can be changed by reimplementing init_scheduler().
The training implementation selects the best model reached after an integer number of epochs based on the validation set.
The hyper parameter 'normalize_by_opnorm' selects whether op should be normalized by the operator norm. In this case, the inputs to model are divided by the operator norm.
-
model
The neural network. Must be initialized by the subclass init_model() implementation.
- Type
torch.nn.Module or None
-
non_normed_op
The original op passed to __init__(), regardless of self.hyper_params['normalize_by_opnorm']. See also op.
- Type
odl.operator.Operator
-
HYPER_PARAMS = {'batch_size': {'default': 64, 'retrain': True}, 'epochs': {'default': 20, 'retrain': True}, 'lr': {'default': 0.01, 'retrain': True}, 'normalize_by_opnorm': {'default': False, 'retrain': True}}
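As a minimal sketch of how such a specification can be read, the snippet below copies the HYPER_PARAMS dict listed above, extracts the default values and applies user overrides. The helper default_hyper_params is hypothetical and not part of dival; how the library itself merges the hyper_params argument with these defaults is not stated in this section.

```python
# Copy of the HYPER_PARAMS specification listed above.
HYPER_PARAMS = {
    'batch_size': {'default': 64, 'retrain': True},
    'epochs': {'default': 20, 'retrain': True},
    'lr': {'default': 0.01, 'retrain': True},
    'normalize_by_opnorm': {'default': False, 'retrain': True},
}


def default_hyper_params(spec, overrides=None):
    """Hypothetical helper: flatten a spec into a dict of values."""
    # Start from the documented defaults.
    values = {name: entry['default'] for name, entry in spec.items()}
    # Apply overrides, rejecting names that the spec does not know.
    for name, value in (overrides or {}).items():
        if name not in spec:
            raise KeyError('unknown hyper parameter: {}'.format(name))
        values[name] = value
    return values


params = default_hyper_params(HYPER_PARAMS, {'epochs': 5})
print(params)
```

Rejecting unknown names early is a design choice of this sketch; it keeps a misspelled key from being silently ignored.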
-
__init__(op, hyper_params=None, num_data_loader_workers=8, use_cuda=True, show_pbar=True, log_dir=None, log_num_validation_samples=0, save_best_learned_params_path=None, torch_manual_seed=1, shuffle='auto', worker_init_fn=None, **kwargs)
- Parameters
op (odl.operator.Operator) – Forward operator.
num_data_loader_workers (int, optional) – Number of parallel workers to use for loading data.
use_cuda (bool, optional) – Whether to use CUDA for the model.
show_pbar (bool, optional) – Whether to show tqdm progress bars during the epochs.
log_dir (str, optional) – Tensorboard log directory (name of sub-directory in utils/logs). If None, no logs are written.
log_num_validation_samples (int, optional) – Number of validation images to store in tensorboard logs. This option only takes effect if log_dir is not None.
save_best_learned_params_path (str, optional) – Save best model weights during training under the specified path by calling save_learned_params().
torch_manual_seed (int, optional) – Fixed seed to set by torch.manual_seed before training. The default is 1. It can be set to None or False to disable the manual seed.
shuffle ({'auto', False, True}, optional) – Whether to use shuffling when loading data. When 'auto' is specified (the default), True is used if and only if the dataset passed to train() supports random access.
worker_init_fn (callable, optional) – Callable worker_init_fn passed to torch.utils.data.DataLoader.__init__(), which can be used to configure the dataset copies for different worker instances (cf. torch’s IterableDataset docs).
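The 'auto' rule for shuffle described above can be sketched as a small function. Both the function name resolve_shuffle and the boolean argument supports_random_access are hypothetical stand-ins; the way dival queries the dataset for random access support is not shown in this section.

```python
def resolve_shuffle(shuffle, supports_random_access):
    """Hypothetical helper mirroring the documented 'auto' rule.

    shuffle: 'auto', True or False, as accepted by the constructor.
    supports_random_access: whether the dataset allows random access
        (assumed to be known by the caller).
    """
    if shuffle == 'auto':
        # Shuffle exactly when random access is available.
        return bool(supports_random_access)
    # An explicit True/False is used as given.
    return bool(shuffle)


print(resolve_shuffle('auto', supports_random_access=True))
print(resolve_shuffle('auto', supports_random_access=False))
```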
-
property opnorm
-
property op
odl.operator.Operator: The forward operator, normalized if self.hyper_params['normalize_by_opnorm'] is True.
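To illustrate what normalizing by the operator norm means, here is a minimal sketch in which a small matrix (plain nested lists) stands in for the ODL operator. The functions matvec, transpose and estimate_opnorm are hypothetical and written for this example only; how dival computes opnorm is not stated in this section, and power iteration is used here as one common way to estimate it.

```python
import math


def matvec(a, x):
    """Multiply matrix a (list of rows) with vector x."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def estimate_opnorm(a, num_iter=100):
    """Estimate the spectral norm of a by power iteration on A^T A.

    Assumes the start vector is not orthogonal to the dominant
    singular vector and that a is not the zero matrix.
    """
    at = transpose(a)
    x = [1.0] * len(a[0])
    for _ in range(num_iter):
        y = matvec(at, matvec(a, x))
        norm_y = math.sqrt(sum(v * v for v in y))
        x = [v / norm_y for v in y]
    # x has unit length, so ||A x|| approximates the operator norm.
    return math.sqrt(sum(v * v for v in matvec(a, x)))


a = [[3.0, 0.0], [0.0, 4.0]]  # toy stand-in for non_normed_op
opnorm = estimate_opnorm(a)
a_normed = [[v / opnorm for v in row] for row in a]  # stand-in for op

x_true = [1.0, 2.0]
y = matvec(a, x_true)               # data produced by the original operator
y_normed = [v / opnorm for v in y]  # data divided by the operator norm
print(opnorm, y_normed, matvec(a_normed, x_true))
```

Dividing the data by the same factor keeps it consistent with the normalized operator, as the last line shows for this toy case.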
-
train(dataset)
Train the reconstructor with a dataset by adapting its parameters.
Should only use the training and validation data from dataset.
- Parameters
dataset (Dataset) – The dataset from which the training data should be used.
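The class description states that training selects the best model reached after an integer number of epochs based on the validation set. A minimal sketch of that idea follows, with a plain dict standing in for the model parameters. The function train_with_best_selection, its callbacks and the "lower is better" validation score are all assumptions made for this example; which metric dival uses is not stated in this section.

```python
import copy


def train_with_best_selection(params, train_epoch, validate, epochs):
    """Hypothetical loop: keep the parameters with the best validation score.

    train_epoch(params) updates params in place for one epoch.
    validate(params) returns a score where lower is better (assumption).
    """
    best_score = float('inf')
    best_params = copy.deepcopy(params)
    for _ in range(epochs):
        train_epoch(params)
        score = validate(params)  # evaluated only after a full epoch
        if score < best_score:
            best_score = score
            best_params = copy.deepcopy(params)
    return best_params, best_score


# Toy problem: each "epoch" moves w by 1, validation prefers w == 3.
params = {'w': 0.0}


def train_epoch(p):
    p['w'] += 1.0


def validate(p):
    return (p['w'] - 3.0) ** 2


best_params, best_score = train_with_best_selection(
    params, train_epoch, validate, epochs=5)
print(best_params, best_score)
```

The deep copy matters: without it, best_params would alias the live parameters and follow them past the best epoch.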
-
init_transform(dataset)
Initialize the transform (_transform) that is applied on each training sample, e.g. for data augmentation. In the default implementation of train(), it is passed to Dataset.create_torch_dataset() when creating the training (but not the validation) torch dataset, which applies the transform to the (tuple of) torch tensor(s) right before returning, i.e. after reshaping to (1,) + orig_shape.
The default implementation of this method disables the transform by assigning None. Called in train() at the beginning, i.e. before calling init_model(), init_optimizer() and init_scheduler().
- Parameters
dataset (dival.datasets.dataset.Dataset) – The dival dataset passed to train().
-
property transform
callable: Transform that is applied on each sample, usually set by init_transform(), which gets called in train().
-
init_model()
Initialize model. Called in train() after calling init_transform(), but before calling init_optimizer() and init_scheduler().
-
init_optimizer(dataset_train)
Initialize the optimizer. Called in train(), after calling init_transform() and init_model(), but before calling init_scheduler().
- Parameters
dataset_train (torch.utils.data.Dataset) – The training (torch) dataset constructed in train().
-
property optimizer
torch.optim.Optimizer: The optimizer, usually set by init_optimizer(), which gets called in train().
-
init_scheduler(dataset_train)
Initialize the learning rate scheduler. Called in train(), after calling init_transform(), init_model() and init_optimizer().
- Parameters
dataset_train (torch.utils.data.Dataset) – The training (torch) dataset constructed in train().
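The documented call order of the four initialization hooks can be sketched with a stand-in class. SketchReconstructor and MyReconstructor are hypothetical and deliberately free of torch and dival so the snippet runs on its own; only the hook names and their order are taken from the descriptions above, and everything else, including the placeholder objects, is an assumption.

```python
class SketchReconstructor:
    """Stand-in for the base class; NOT the dival implementation."""

    def __init__(self):
        self.calls = []  # records the order in which hooks run
        self.model = None
        self._transform = None

    def init_transform(self, dataset):
        self.calls.append('init_transform')
        self._transform = None  # default: transform disabled

    def init_model(self):
        # The one hook that subclasses are required to implement.
        raise NotImplementedError

    def init_optimizer(self, dataset_train):
        self.calls.append('init_optimizer')

    def init_scheduler(self, dataset_train):
        self.calls.append('init_scheduler')

    def train(self, dataset):
        self.init_transform(dataset)
        dataset_train = list(dataset)  # placeholder for the torch dataset
        self.init_model()
        self.init_optimizer(dataset_train)
        self.init_scheduler(dataset_train)
        # A real training loop would follow here.


class MyReconstructor(SketchReconstructor):
    def init_model(self):
        self.calls.append('init_model')
        self.model = object()  # placeholder for a torch.nn.Module


reco = MyReconstructor()
reco.train(dataset=[(0.0, 0.0)])
print(reco.calls)
```

Because the optimizer and scheduler hooks run after init_model(), an overriding implementation can rely on model already being set.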
-
property scheduler
torch learning rate scheduler: The scheduler, usually set by init_scheduler(), which gets called in train().
-
property batch_size
-
property epochs
-
property lr
-
property normalize_by_opnorm
-
save_learned_params(path)
Save learned parameters to file.
- Parameters
path (str) – Path at which the learned parameters should be saved. Implementations may interpret this as a file path or as a directory path for multiple files (which then should be created if it does not exist). If the implementation expects a file path, it should accept it without a file extension.
-
load_learned_params(path, convert_data_parallel='auto')
Load learned parameters from file.
- Parameters
path (str) – Path at which the learned parameters are stored. Implementations may interpret this as a file path or as a directory path for multiple files. If the implementation expects a file path, it should accept it without file ending.
convert_data_parallel (bool or {'auto', 'keep'}, optional) – Whether to automatically convert the model weight names if model is an nn.DataParallel model but the stored state dict stems from a non-data-parallel model, or vice versa.
'auto' or True: Auto-convert weight names, depending on the type of model.
'keep' or False: Do not convert weight names. Convert to plain weight names.
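The weight-name conversion can be illustrated without torch. An nn.DataParallel wrapper holds the wrapped network as its module attribute, so its state dict keys carry a 'module.' prefix that a plain model's keys lack. The function convert_state_dict_keys below is a hypothetical sketch operating on ordinary dicts, not dival's implementation; in particular, treating 'keep' or False as "leave the names untouched" is an assumption of this sketch.

```python
PREFIX = 'module.'  # key prefix added by nn.DataParallel


def convert_state_dict_keys(state_dict, model_is_data_parallel,
                            convert='auto'):
    """Hypothetical helper: rename keys to match the model type."""
    if convert in ('keep', False):
        return dict(state_dict)  # assumption: names stay as stored
    stored_is_data_parallel = bool(state_dict) and all(
        key.startswith(PREFIX) for key in state_dict)
    if model_is_data_parallel and not stored_is_data_parallel:
        # Plain weights loaded into a data-parallel model: add the prefix.
        return {PREFIX + key: value for key, value in state_dict.items()}
    if not model_is_data_parallel and stored_is_data_parallel:
        # Data-parallel weights loaded into a plain model: strip the prefix.
        return {key[len(PREFIX):]: value
                for key, value in state_dict.items()}
    return dict(state_dict)


plain = {'conv.weight': 1, 'conv.bias': 2}
wrapped = convert_state_dict_keys(plain, model_is_data_parallel=True)
print(wrapped)
print(convert_state_dict_keys(wrapped, model_is_data_parallel=False))
```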
-