dival.reconstructors.standard_learned_reconstructor module

class dival.reconstructors.standard_learned_reconstructor.StandardLearnedReconstructor(op, hyper_params=None, num_data_loader_workers=8, use_cuda=True, show_pbar=True, log_dir=None, log_num_validation_samples=0, save_best_learned_params_path=None, torch_manual_seed=1, shuffle='auto', worker_init_fn=None, **kwargs)

Bases: dival.reconstructors.reconstructor.LearnedReconstructor

Standard learned reconstructor base class.

Provides a default implementation that only requires subclasses to implement init_model().

By default, the Adam optimizer is used. This can be changed by reimplementing init_optimizer(). Also, a OneCycleLR scheduler is used by default, which can be changed by reimplementing init_scheduler().

During training, the model is evaluated on the validation set after each epoch, and the best of these end-of-epoch models (i.e. models reached after an integer number of epochs) is kept.

The hyper parameter 'normalize_by_opnorm' selects whether op should be normalized by its operator norm. If it is True, the inputs to model are divided by the operator norm as well.
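For a linear operator, dividing the inputs by the operator norm is consistent with normalizing op itself. The sketch below checks this numerically with NumPy, assuming a random dense matrix A as a stand-in for op and using the spectral norm as the operator norm; dival works with ODL operators and may estimate the norm differently.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((30, 20))  # stand-in for the forward operator op
x = rng.standard_normal(20)        # flattened ground-truth image
y = A @ x                          # observation from the non-normalized op

opnorm = np.linalg.norm(A, 2)      # spectral norm = largest singular value
A_normed = A / opnorm              # analogue of the normalized op

# Dividing the model input by the operator norm gives exactly the
# observation that the normalized operator would produce.
y_normed = y / opnorm
assert np.allclose(A_normed @ x, y_normed)
assert np.isclose(np.linalg.norm(A_normed, 2), 1.0)
```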

model

The neural network. Must be initialized by the subclass init_model() implementation.

Type

torch.nn.Module or None

non_normed_op

The original op passed to __init__(), regardless of self.hyper_params['normalize_by_opnorm']. See also op.

Type

odl.operator.Operator

HYPER_PARAMS = {'batch_size': {'default': 64, 'retrain': True}, 'epochs': {'default': 20, 'retrain': True}, 'lr': {'default': 0.01, 'retrain': True}, 'normalize_by_opnorm': {'default': False, 'retrain': True}}
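The HYPER_PARAMS entries pair each name with a default value. A minimal sketch of how user-supplied values could be merged with these defaults follows; resolve_hyper_params and needs_retrain are hypothetical helpers, rejecting unknown names is a choice of this sketch, and reading 'retrain': True as "changing this value invalidates trained parameters" is an assumption, not documented behavior.

```python
import copy

HYPER_PARAMS = {
    'batch_size': {'default': 64, 'retrain': True},
    'epochs': {'default': 20, 'retrain': True},
    'lr': {'default': 0.01, 'retrain': True},
    'normalize_by_opnorm': {'default': False, 'retrain': True},
}


def resolve_hyper_params(user_params=None, spec=HYPER_PARAMS):
    """Return a dict with every hyper parameter set, filling gaps from defaults."""
    resolved = {name: copy.deepcopy(entry['default']) for name, entry in spec.items()}
    for name, value in (user_params or {}).items():
        if name not in spec:
            raise ValueError(f"unknown hyper parameter {name!r}")
        resolved[name] = value
    return resolved


def needs_retrain(old, new, spec=HYPER_PARAMS):
    """True if any changed hyper parameter is flagged with 'retrain' (assumed meaning)."""
    return any(spec[k]['retrain'] and old[k] != new[k] for k in spec)


hp = resolve_hyper_params({'lr': 1e-3})
print(hp)  # {'batch_size': 64, 'epochs': 20, 'lr': 0.001, 'normalize_by_opnorm': False}
```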
__init__(op, hyper_params=None, num_data_loader_workers=8, use_cuda=True, show_pbar=True, log_dir=None, log_num_validation_samples=0, save_best_learned_params_path=None, torch_manual_seed=1, shuffle='auto', worker_init_fn=None, **kwargs)
Parameters
  • op (odl.operator.Operator) – Forward operator.

  • num_data_loader_workers (int, optional) – Number of parallel workers to use for loading data.

  • use_cuda (bool, optional) – Whether to use CUDA for the model.

  • show_pbar (bool, optional) – Whether to show tqdm progress bars during the epochs.

  • log_dir (str, optional) – TensorBoard log directory (name of sub-directory in utils/logs). If None, no logs are written.

  • log_num_validation_samples (int, optional) – Number of validation images to store in TensorBoard logs. This option only takes effect if log_dir is not None.

  • save_best_learned_params_path (str, optional) – Save best model weights during training under the specified path by calling save_learned_params().

  • torch_manual_seed (int, optional) – Fixed seed passed to torch.manual_seed() before training. The default is 1. Set it to None or False to disable the manual seed.

  • shuffle ({'auto', False, True}, optional) – Whether to use shuffling when loading data. When 'auto' is specified (the default), True is used iff the dataset passed to train() supports random access.

  • worker_init_fn (callable, optional) – Callable passed as worker_init_fn to torch.utils.data.DataLoader.__init__(), which can be used to configure the dataset copies for the different worker processes (cf. torch’s IterableDataset docs).
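The 'auto' option for shuffle reduces to a small decision rule. resolve_shuffle below is a hypothetical helper, and supports_random_access is an assumed boolean describing the dataset passed to train().

```python
def resolve_shuffle(shuffle, supports_random_access):
    """Map the `shuffle` argument to the bool handed to the data loader.

    Hypothetical helper: explicit booleans win; 'auto' shuffles only
    datasets that support random access.
    """
    if isinstance(shuffle, bool):
        return shuffle
    if shuffle == 'auto':
        return bool(supports_random_access)
    raise ValueError("shuffle must be 'auto', True or False")


print(resolve_shuffle('auto', True), resolve_shuffle('auto', False), resolve_shuffle(False, True))
# True False False
```

Checking isinstance(shuffle, bool) first avoids treating arbitrary truthy values as valid options.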

property opnorm
property op

odl.operator.Operator: The forward operator, normalized if self.hyper_params['normalize_by_opnorm'] is True.

eval(test_data)
train(dataset)

Train the reconstructor with a dataset by adapting its parameters.

Should only use the training and validation data from dataset.

Parameters

dataset (Dataset) – The dataset from which the training data should be used.
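The documented order of the initialization hooks in train() (see init_transform(), init_model(), init_optimizer() and init_scheduler()) can be summarized with a skeleton class that only records the calls. HookOrderDemo is hypothetical and omits the data loading, epoch loop and validation that the real train() performs.

```python
class HookOrderDemo:
    """Hypothetical skeleton mirroring the documented hook order in train()."""

    def __init__(self):
        self.calls = []

    def init_transform(self, dataset):
        self.calls.append('init_transform')

    def init_model(self):
        self.calls.append('init_model')

    def init_optimizer(self, dataset_train):
        self.calls.append('init_optimizer')

    def init_scheduler(self, dataset_train):
        self.calls.append('init_scheduler')

    def train(self, dataset):
        self.init_transform(dataset)           # first, so the transform can be used
        dataset_train = list(dataset)          # placeholder for the torch training dataset
        self.init_model()                      # model must exist before the optimizer
        self.init_optimizer(dataset_train)     # optimizer must exist before the scheduler
        self.init_scheduler(dataset_train)
        # ... the epoch loop with per-epoch validation would follow here ...


demo = HookOrderDemo()
demo.train([0, 1, 2])
print(demo.calls)  # ['init_transform', 'init_model', 'init_optimizer', 'init_scheduler']
```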

init_transform(dataset)

Initialize the transform (_transform) that is applied to each training sample, e.g. for data augmentation. In the default implementation of train(), it is passed to Dataset.create_torch_dataset() when creating the training (but not the validation) torch dataset. That dataset applies the transform to the (tuple of) torch tensor(s) right before returning a sample, i.e. after reshaping each tensor to (1,) + orig_shape.

The default implementation of this method disables the transform by assigning None. Called in train() at the beginning, i.e. before calling init_model(), init_optimizer() and init_scheduler().

Parameters

dataset (dival.datasets.dataset.Dataset) – The dival dataset passed to train().
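A transform receives the reshaped sample tuple and returns a new one. The sketch below adds Gaussian noise to the observation only, as one possible augmentation; make_noise_transform is a hypothetical name, NumPy arrays stand in for the torch tensors, and in a subclass such a callable would presumably be assigned to self._transform inside an overridden init_transform().

```python
import numpy as np


def make_noise_transform(sigma=0.01, seed=None):
    """Hypothetical augmentation: add Gaussian noise to the observation only."""
    rng = np.random.default_rng(seed)

    def transform(sample):
        obs, gt = sample                  # each already shaped (1,) + orig_shape
        noisy_obs = obs + sigma * rng.standard_normal(obs.shape)  # new array
        return noisy_obs, gt              # ground truth is left untouched

    return transform


obs = np.zeros((1, 4, 3))  # stand-in observation, reshaped to (1,) + (4, 3)
gt = np.ones((1, 8, 8))    # stand-in ground truth, reshaped to (1,) + (8, 8)
t = make_noise_transform(sigma=0.1, seed=0)
noisy_obs, same_gt = t((obs, gt))
print(noisy_obs.shape, same_gt is gt)  # (1, 4, 3) True
```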

property transform

callable: Transform that is applied on each sample, usually set by init_transform(), which gets called in train().

init_model()

Initialize the model. Called in train() after calling init_transform(), but before calling init_optimizer() and init_scheduler().

init_optimizer(dataset_train)

Initialize the optimizer. Called in train(), after calling init_transform() and init_model(), but before calling init_scheduler().

Parameters

dataset_train (torch.utils.data.Dataset) – The training (torch) dataset constructed in train().

property optimizer

torch.optim.Optimizer: The optimizer, usually set by init_optimizer(), which gets called in train().

init_scheduler(dataset_train)

Initialize the learning rate scheduler. Called in train(), after calling init_transform(), init_model() and init_optimizer().

Parameters

dataset_train (torch.utils.data.Dataset) – The training (torch) dataset constructed in train().
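init_scheduler() receives dataset_train presumably because a per-batch schedule such as torch.optim.lr_scheduler.OneCycleLR needs its step count up front (it accepts either total_steps or epochs together with steps_per_epoch). The arithmetic is sketched below with a hypothetical helper, using the HYPER_PARAMS defaults batch_size=64 and epochs=20, and assuming the data loader keeps the last incomplete batch unless drop_last is set.

```python
def one_cycle_steps(num_train_samples, batch_size=64, epochs=20, drop_last=False):
    """Return (steps_per_epoch, total_steps) for a per-batch LR schedule."""
    if drop_last:
        steps_per_epoch = num_train_samples // batch_size
    else:
        # Integer ceiling division: the last partial batch counts as a step.
        steps_per_epoch = (num_train_samples + batch_size - 1) // batch_size
    return steps_per_epoch, steps_per_epoch * epochs


print(one_cycle_steps(1000))  # (16, 320)
```

Using integer ceiling division avoids floating-point rounding for very large datasets.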

property scheduler

torch learning rate scheduler: The scheduler, usually set by init_scheduler(), which gets called in train().

property batch_size
property epochs
property lr
property normalize_by_opnorm
save_learned_params(path)

Save learned parameters to file.

Parameters

path (str) – Path at which the learned parameters should be saved. Implementations may interpret this as a file path or as a directory path for multiple files (which then should be created if it does not exist). If the implementation expects a file path, it should accept it without file ending.
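A file-path implementation can honor the "accept it without file ending" rule by appending the ending only when it is missing. resolve_param_file is a hypothetical helper, and the '.pt' ending is an assumed PyTorch convention rather than something this page specifies.

```python
import os


def resolve_param_file(path, ext='.pt', create_dirs=False):
    """Hypothetical helper: append `ext` if missing, optionally create the parent dir."""
    file_path = path if path.endswith(ext) else path + ext
    parent = os.path.dirname(file_path)
    if create_dirs and parent:  # dirname is '' for bare file names
        os.makedirs(parent, exist_ok=True)
    return file_path


print(resolve_param_file('checkpoints/unet_best'))     # checkpoints/unet_best.pt
print(resolve_param_file('checkpoints/unet_best.pt'))  # checkpoints/unet_best.pt
```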

load_learned_params(path, convert_data_parallel='auto')

Load learned parameters from file.

Parameters
  • path (str) – Path at which the learned parameters are stored. Implementations may interpret this as a file path or as a directory path for multiple files. If the implementation expects a file path, it should accept it without file ending.

  • convert_data_parallel (bool or {'auto', 'keep'}, optional) –

    Whether to automatically convert the model weight names if model is an nn.DataParallel model but the stored state dict stems from a non-data-parallel model, or vice versa.

    'auto' or True:

    Auto-convert weight names, depending on the type of model.

    'keep' or False:

    Do not convert weight names.
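torch.nn.DataParallel stores the wrapped network as its module attribute, so the keys of its state dict carry a 'module.' prefix. Converting between the two naming schemes therefore amounts to adding or stripping that prefix, as in this hypothetical sketch (integers stand in for weight tensors; in 'auto' mode, to_data_parallel would follow from whether model is an nn.DataParallel instance).

```python
PREFIX = 'module.'


def convert_state_dict_keys(state_dict, to_data_parallel):
    """Add or strip the 'module.' prefix that nn.DataParallel puts on weight names."""
    converted = {}
    for key, value in state_dict.items():
        if to_data_parallel and not key.startswith(PREFIX):
            key = PREFIX + key            # plain -> data-parallel naming
        elif not to_data_parallel and key.startswith(PREFIX):
            key = key[len(PREFIX):]       # data-parallel -> plain naming
        converted[key] = value
    return converted


plain = {'conv.weight': 1, 'conv.bias': 2}
wrapped = convert_state_dict_keys(plain, to_data_parallel=True)
print(list(wrapped))  # ['module.conv.weight', 'module.conv.bias']
print(list(convert_state_dict_keys(wrapped, to_data_parallel=False)))  # ['conv.weight', 'conv.bias']
```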