dival.datasets.dataset module

Provides the dataset base classes.

class dival.datasets.dataset.Dataset(space=None)

Bases: object

Dataset base class.

Subclasses must either implement generator() or provide random access by implementing get_sample() and get_samples() (which should then be indicated by setting the attribute random_access = True).
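The two ways of satisfying this contract can be sketched as follows. This is a minimal illustration, not dival's code: `DatasetSketch`, `CountingDataset` and `SquaresDataset` are hypothetical names, and plain lists stand in for odl elements so that the sketch runs without dival or ODL installed.

```python
class DatasetSketch:
    """Hypothetical stand-in for dival's Dataset base class (not the real one)."""
    random_access = False

    def generator(self, part='train'):
        raise NotImplementedError


class CountingDataset(DatasetSketch):
    """Generator-only subclass: samples can only be read in order."""
    train_len = 3

    def generator(self, part='train'):
        for i in range(self.train_len):
            yield [float(i), float(i)]  # one element per sample


class SquaresDataset(DatasetSketch):
    """Random-access subclass: any sample can be fetched by index."""
    train_len = 4
    random_access = True  # the preferred way to advertise random access

    def get_sample(self, index, part='train', out=None):
        return [float(index ** 2)]

    def get_samples(self, key, part='train', out=None):
        # Resolve a slice against the part length; a range is used as is.
        if isinstance(key, slice):
            key = range(*key.indices(self.train_len))
        return [self.get_sample(i, part=part) for i in key]


print(list(CountingDataset().generator()))  # [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
print(SquaresDataset().get_samples(slice(1, 3)))  # [[1.0], [4.0]]
```

The generator-only subclass leaves random_access at False, so callers know they must iterate; the random-access subclass sets it to True explicitly rather than relying on a duck-type check.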

space

The spaces of the elements of samples as a tuple. If only one element per sample is provided, this attribute is the space of the element (i.e., no tuple). It is strongly recommended to set this attribute in subclasses, as some functionality may depend on it.

Type

[tuple of ] odl.space.base_tensors.TensorSpace or None

shape

The shapes of the elements of samples as a tuple of tuple of int. If only one element per sample is provided, this attribute is the shape of the element (i.e., not a tuple of tuple of int, but a tuple of int).

Type

[tuple of ] tuple of int, optional

train_len

Number of training samples.

Type

int, optional

validation_len

Number of validation samples.

Type

int, optional

test_len

Number of test samples.

Type

int, optional

random_access

Whether the dataset supports random access via self.get_sample and self.get_samples. Setting this attribute is the preferred way for subclasses to indicate whether they support random access.

Type

bool, optional

num_elements_per_sample

Number of elements per sample. E.g. 1 for a ground truth dataset or 2 for a dataset of pairs of observation and ground truth.

Type

int, optional

standard_dataset_name

Datasets returned by get_standard_dataset have this attribute, which gives their name.

Type

str, optional

__init__(space=None)

The attributes that potentially should be set by the subclass are: space (can also be set by argument), shape, train_len, validation_len, test_len, random_access and num_elements_per_sample.

Parameters

space ([tuple of ] odl.space.base_tensors.TensorSpace, optional) – The spaces of the elements of samples as a tuple. If only one element per sample is provided, this attribute is the space of the element (i.e., no tuple). It is strongly recommended to set space in subclasses, as some functionality may depend on it.

generator(part='train')

Yield data.

The default implementation calls get_sample() if the dataset implements it (i.e., supports random access).

Parameters

part ({'train', 'validation', 'test'}, optional) – Whether to yield train, validation or test data. Default is 'train'.

Yields

data (odl element or tuple of odl elements) – Sample of the dataset.
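The default behaviour described above amounts to iterating over the indices of a part and calling get_sample() for each. The following is a minimal sketch under that reading, not dival's implementation; `DefaultGeneratorSketch` and `TinyDataset` are hypothetical, and the dictionary lookup in `get_len` is an assumed way of mapping a part name to its length.

```python
class DefaultGeneratorSketch:
    """Hypothetical base class mirroring the documented default generator()."""
    random_access = False

    def get_len(self, part='train'):
        # Map the part name to the corresponding length attribute.
        return {'train': self.train_len, 'validation': self.validation_len,
                'test': self.test_len}[part]

    def generator(self, part='train'):
        # Fall back to random access: fetch samples 0, 1, ... in order.
        if not self.random_access:
            raise NotImplementedError(
                'implement generator() or provide random access')
        for i in range(self.get_len(part)):
            yield self.get_sample(i, part=part)


class TinyDataset(DefaultGeneratorSketch):
    random_access = True
    train_len, validation_len, test_len = 3, 1, 2

    def get_sample(self, index, part='train', out=None):
        return (part, index)


ds = TinyDataset()
print(list(ds.generator('test')))  # [('test', 0), ('test', 1)]
```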

get_train_generator()

get_validation_generator()

get_test_generator()

get_len(part='train')

Return the number of elements the generator will yield.

Parameters

part ({'train', 'validation', 'test'}, optional) – Whether to return the number of train, validation or test elements. Default is 'train'.

get_train_len()

Return the number of samples the train generator will yield.

get_validation_len()

Return the number of samples the validation generator will yield.

get_test_len()

Return the number of samples the test generator will yield.

get_shape()

Return the shape of each element.

Returns shape if it is set. Otherwise, it is inferred from space (which is strongly recommended to be set in every subclass). If space is not set either, a NotImplementedError is raised.

Returns

shape

Return type

[tuple of ] tuple

get_num_elements_per_sample()

Return number of elements per sample.

Returns num_elements_per_sample if it is set. Otherwise, it is inferred from space (which is strongly recommended to be set in every subclass). If space is not set either, a NotImplementedError is raised.

Returns

num_elements_per_sample

Return type

int
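Both get_shape() and get_num_elements_per_sample() fall back to space when the corresponding attribute is not set. A minimal sketch of that fallback logic follows; `SpaceSketch` is a hypothetical stand-in for an odl TensorSpace exposing only `.shape`, and the rule "tuple of spaces means one element per space" is assumed from the attribute descriptions above rather than taken from dival's code.

```python
class SpaceSketch:
    """Hypothetical stand-in for an odl TensorSpace; only .shape is used."""

    def __init__(self, shape):
        self.shape = shape


class InferenceSketch:
    """Hypothetical sketch of the documented fallbacks to `space`."""

    def __init__(self, space=None, shape=None, num_elements_per_sample=None):
        self.space = space
        if shape is not None:
            self.shape = shape
        if num_elements_per_sample is not None:
            self.num_elements_per_sample = num_elements_per_sample

    def get_shape(self):
        if hasattr(self, 'shape'):
            return self.shape  # an explicitly set shape wins
        if self.space is None:
            raise NotImplementedError('neither shape nor space is set')
        if isinstance(self.space, tuple):
            return tuple(s.shape for s in self.space)
        return self.space.shape

    def get_num_elements_per_sample(self):
        if hasattr(self, 'num_elements_per_sample'):
            return self.num_elements_per_sample
        if self.space is None:
            raise NotImplementedError(
                'neither num_elements_per_sample nor space is set')
        # A tuple of spaces means one element per space; otherwise one.
        return len(self.space) if isinstance(self.space, tuple) else 1


pair_ds = InferenceSketch(space=(SpaceSketch((32,)), SpaceSketch((64, 64))))
print(pair_ds.get_shape(), pair_ds.get_num_elements_per_sample())
# ((32,), (64, 64)) 2
```

A plain class is used for `SpaceSketch` on purpose: a namedtuple would itself be a tuple and would be mistaken for a tuple of spaces by the isinstance checks.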

get_data_pairs(part='train', n=None)

Return the first samples of a data part as a DataPairs object.

Only supports datasets with two elements per sample.

Parameters
  • part ({'train', 'validation', 'test'}, optional) – The data part. Default is 'train'.

  • n (int, optional) – Number of pairs (from beginning). If None, all available data is used (the default).
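Collecting the first n pairs from a generator can be sketched with itertools.islice, which also handles n=None (consume everything). This is an illustration under assumptions: `PairsSketch` is a hypothetical stand-in for DataPairs that simply holds two parallel lists, and samples are assumed to be ordered as (observation, ground truth).

```python
from itertools import islice


class PairsSketch:
    """Hypothetical stand-in for dival's DataPairs: two parallel lists."""

    def __init__(self, observations, ground_truth):
        self.observations = observations
        self.ground_truth = ground_truth


def get_data_pairs_sketch(generator, part='train', n=None):
    """Collect the first n (observation, ground truth) pairs of a part."""
    # islice with stop=None consumes the whole generator.
    pairs = list(islice(generator(part=part), n))
    observations = [obs for obs, _ in pairs]
    ground_truth = [gt for _, gt in pairs]
    return PairsSketch(observations, ground_truth)


def toy_pair_generator(part='train'):
    for i in range(5):
        yield (2 * i, i)  # (observation, ground truth)


dp = get_data_pairs_sketch(toy_pair_generator, n=3)
print(dp.observations, dp.ground_truth)  # [0, 2, 4] [0, 1, 2]
```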

get_data_pairs_per_index(part='train', index=None)

Return specific samples of a data part as a DataPairs object.

Only supports datasets with two elements per sample.

For datasets not supporting random access, samples are extracted from generator(), which can be computationally expensive.

Parameters
  • part ({'train', 'validation', 'test'}, optional) – The data part. Default is 'train'.

  • index (int or list of int, optional) – Indices of the samples in the data part. Default is [0].
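Why extraction via generator() is expensive becomes clear from a sketch: every sample up to the largest requested index has to be produced, even if only one of them is kept. `samples_per_index_sketch` is a hypothetical helper illustrating that access pattern, not dival's implementation; it returns plain samples instead of a DataPairs object.

```python
def samples_per_index_sketch(generator, part='train', index=None):
    """Pick samples by index from a dataset that only offers generator()."""
    if index is None:
        index = [0]
    elif isinstance(index, int):
        index = [index]
    wanted = set(index)
    last = max(wanted)
    found = {}
    for i, sample in enumerate(generator(part=part)):
        if i in wanted:
            found[i] = sample
        if i >= last:
            break  # no need to advance the generator further
    # Preserve the requested order (duplicates are allowed).
    return [found[i] for i in index]


def toy_generator(part='train'):
    for i in range(6):
        yield i * 10


print(samples_per_index_sketch(toy_generator, index=[4, 1]))  # [40, 10]
```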

create_torch_dataset(part='train', reshape=None, transform=None)

Create a torch dataset wrapper for one part of this dataset.

If supports_random_access() returns False, a subclass of torch.utils.data.IterableDataset is returned that fetches samples via generator(). Note: when using torch’s DataLoader with multiple workers, you may want to configure the dataset individually for each worker; see the PyTorch docs on IterableDataset. For this purpose it can be useful to modify the wrapped dival dataset in worker_init_fn(), where it can be accessed via torch.utils.data.get_worker_info().dataset.dataset.

If supports_random_access() returns True, a subclass of torch.utils.data.Dataset is returned that retrieves samples using get_sample().

Parameters
  • part ({'train', 'validation', 'test'}, optional) – The data part. Default is 'train'.

  • reshape (tuple of (tuple or None), optional) – Shapes to which the elements of each sample will be reshaped. If None is passed for an element, no reshape is applied.

  • transform (callable, optional) – Transform to be applied on each sample, useful for augmentation. Default: None, i.e. no transform.

Returns

dataset – The torch dataset wrapping this dataset. The wrapped dival dataset is assigned to the attribute dataset.dataset.

Return type

torch.utils.data.Dataset or torch.utils.data.IterableDataset
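For the random-access case, the wrapper follows the map-style protocol (__len__ and __getitem__) that subclasses of torch.utils.data.Dataset implement. The following plain-Python sketch shows that protocol together with the transform hook, without importing torch; `MapStyleWrapperSketch` and `ToyPairDataset` are hypothetical, and the reshape option is omitted for brevity.

```python
class MapStyleWrapperSketch:
    """Plain-Python sketch of a map-style wrapper (no torch import).

    Implements the __len__/__getitem__ protocol that subclasses of
    torch.utils.data.Dataset follow; reshaping is omitted.
    """

    def __init__(self, dataset, part='train', transform=None):
        self.dataset = dataset  # the wrapped dataset stays accessible
        self.part = part
        self.transform = transform

    def __len__(self):
        return self.dataset.get_len(self.part)

    def __getitem__(self, idx):
        sample = self.dataset.get_sample(idx, part=self.part)
        if self.transform is not None:
            sample = self.transform(sample)  # e.g. augmentation
        return sample


class ToyPairDataset:
    """Hypothetical random-access dataset of (observation, ground truth)."""

    def get_len(self, part='train'):
        return 3

    def get_sample(self, index, part='train', out=None):
        return (index + 0.5, index)


wrapped = MapStyleWrapperSketch(
    ToyPairDataset(), transform=lambda s: (s[0] * 2, s[1]))
print(len(wrapped), wrapped[2])  # 3 (5.0, 2)
```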

create_keras_generator(part='train', batch_size=1, shuffle=True, reshape=None)

Create a keras data generator wrapper for one part of this dataset.

If supports_random_access() returns False, a generator wrapping generator() is returned. In this case no shuffling is performed regardless of the passed shuffle parameter. Also, parallel data loading (with multiple workers) is not applicable.

If supports_random_access() returns True, a tf.keras.utils.Sequence is returned, which is implemented using get_sample(). For datasets that support parallel calls to get_sample(), the returned data generator (sequence) can be used by multiple workers.

Parameters
  • part ({'train', 'validation', 'test'}, optional) – The data part. Default is 'train'.

  • batch_size (int, optional) – Batch size. Default is 1.

  • shuffle (bool, optional) – Whether to shuffle samples each epoch. This option has no effect if supports_random_access() returns False, since in that case samples are fetched directly from generator(). The default is True.

  • reshape (tuple of (tuple or None), optional) – Shapes to which the elements of each sample will be reshaped. If None is passed for an element, no reshape is applied.
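In the random-access case, batching with per-epoch shuffling boils down to permuting the sample indices and slicing them into batches. The sketch below mirrors the __len__/__getitem__/on_epoch_end structure of tf.keras.utils.Sequence without importing TensorFlow. `BatchSequenceSketch` is hypothetical, and keeping a smaller final batch (rather than dropping it) is an assumption, not necessarily what dival does.

```python
import math
import random


class BatchSequenceSketch:
    """Plain-Python sketch of a batched, shuffled sequence in the style of
    tf.keras.utils.Sequence (__len__, __getitem__, on_epoch_end)."""

    def __init__(self, get_sample, num_samples, batch_size=1, shuffle=True,
                 seed=None):
        self.get_sample = get_sample
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = list(range(num_samples))
        self.rng = random.Random(seed)
        if self.shuffle:
            self.rng.shuffle(self.indices)

    def __len__(self):
        # Number of batches; the last one may be smaller than batch_size.
        return math.ceil(len(self.indices) / self.batch_size)

    def __getitem__(self, batch_idx):
        start = batch_idx * self.batch_size
        batch_indices = self.indices[start:start + self.batch_size]
        return [self.get_sample(i) for i in batch_indices]

    def on_epoch_end(self):
        # Reshuffle so each epoch visits the samples in a new order.
        if self.shuffle:
            self.rng.shuffle(self.indices)


seq = BatchSequenceSketch(lambda i: i * i, 5, batch_size=2, shuffle=False)
print([seq[b] for b in range(len(seq))])  # [[0, 1], [4, 9], [16]]
```

Because each batch is computed only from an index range, independent workers can request different batches concurrently, provided get_sample() itself tolerates parallel calls.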

get_sample(index, part='train', out=None)

Get single sample by index.

Parameters
  • index (int) – Index of the sample.

  • part ({'train', 'validation', 'test'}, optional) – The data part. Default is 'train'.

  • out (array-like or tuple of (array-like or bool) or None) –

    Array(s) (or e.g. odl element(s)) to which the sample is written. A tuple should be passed if the dataset returns two or more arrays per sample (i.e. pairs, …). If a tuple element is a bool, it has the following meaning:

    True

    Create a new array and return it.

    False

    Do not return this array, i.e. None is returned.

Returns

sample – E.g. for a pair dataset: (array, None) if out=(True, False).

Return type

[tuple of ] (array-like or None)
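The out semantics for a pair dataset can be sketched as follows: each tuple entry is either True (return a new array), False (return None for that element), or an existing array that is filled in place. `get_sample_with_out_sketch` and the `compute_elements` callable are hypothetical, and Python lists stand in for arrays. For simplicity all elements are computed first; an actual implementation could skip elements whose flag is False.

```python
def get_sample_with_out_sketch(compute_elements, out=None):
    """Sketch of the documented `out` semantics for a pair dataset.

    compute_elements: hypothetical callable returning the sample elements
    (e.g. observation and ground truth) as lists.
    out: None, or a tuple with one entry per element: True (new array),
    False (return None) or a list to be filled in place.
    """
    elements = compute_elements()
    if out is None:
        return elements
    result = []
    for elem, target in zip(elements, out):
        if target is True:
            result.append(list(elem))  # create and return a new array
        elif target is False:
            result.append(None)  # do not return this element
        else:
            target[:] = elem  # write into the provided array
            result.append(target)
    return tuple(result)


def toy_pair():
    return ([1.0, 2.0], [3.0])  # (observation, ground truth)


print(get_sample_with_out_sketch(toy_pair, out=(True, False)))
# ([1.0, 2.0], None)
```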

get_samples(key, part='train', out=None)

Get samples by slice or range.

The default implementation calls get_sample() if the dataset implements it.

Parameters
  • key (slice or range) – Indices of the samples.

  • part ({'train', 'validation', 'test'}, optional) – The data part. Default is 'train'.

  • out (array-like or tuple of (array-like or bool) or None) –

    Array(s) (or e.g. odl element(s)) to which the samples are written. The first dimension must match the number of samples requested. A tuple should be passed if the dataset returns two or more arrays per sample (i.e. pairs, …). If a tuple element is a bool, it has the following meaning:

    True

    Create a new array and return it.

    False

    Do not return this array, i.e. None is returned.

Returns

samples – If the dataset has multiple arrays per sample, a tuple holding arrays is returned. E.g. for a pair dataset: (array, None) if out=(True, False). The samples are stacked in the first (additional) dimension of each array.

Return type

[tuple of ] (array-like or None)
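A default get_samples() built on get_sample() needs two steps: resolve a slice against the length of the part, then stack the samples so that each element of the tuple gathers one element from every sample along a new first dimension. The sketch below uses lists instead of arrays and omits the out argument; `get_samples_sketch` and `toy_sample` are hypothetical names.

```python
def get_samples_sketch(get_sample, length, key):
    """Sketch of a default get_samples() built on get_sample().

    get_sample(i) is assumed to return a tuple of elements (e.g. a pair);
    the result holds one list per element, with samples stacked along the
    first (new) dimension.
    """
    if isinstance(key, slice):
        key = range(*key.indices(length))
    samples = [get_sample(i) for i in key]
    # Transpose: list of (obs, gt) tuples -> (list of obs, list of gt).
    return tuple(list(column) for column in zip(*samples))


def toy_sample(i):
    return (i, -i)  # (observation, ground truth)


print(get_samples_sketch(toy_sample, 5, slice(1, 4)))  # ([1, 2, 3], [-1, -2, -3])
```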

supports_random_access()

Whether random access seems to be supported.

If the object has the attribute self.random_access, its value is returned (this is the preferred way for subclasses to indicate whether they support random access). Otherwise, a simple duck-type check is performed which tries to get the first sample by random access.

Returns

supports – True if the dataset supports random access, otherwise False.

Return type

bool
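The lookup order described above (explicit attribute first, then a duck-type probe of the first sample) can be sketched as a free function. Which exceptions count as "no random access" is an assumption here: the sketch treats NotImplementedError and AttributeError as a negative answer. All class names are hypothetical.

```python
def supports_random_access_sketch(dataset):
    """Prefer an explicit `random_access` attribute, else try sample 0."""
    if hasattr(dataset, 'random_access'):
        return dataset.random_access
    try:
        dataset.get_sample(0)
    except (NotImplementedError, AttributeError):
        return False
    return True


class ExplicitlyDisabled:
    random_access = False  # the attribute wins even though get_sample works

    def get_sample(self, index, part='train', out=None):
        return index


class DuckTyped:
    def get_sample(self, index, part='train', out=None):
        return index


class GeneratorOnly:
    def get_sample(self, index, part='train', out=None):
        raise NotImplementedError


print(supports_random_access_sketch(DuckTyped()),
      supports_random_access_sketch(GeneratorOnly()))  # True False
```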

class dival.datasets.dataset.ObservationGroundTruthPairDataset(ground_truth_gen, forward_op, post_processor=None, train_len=None, validation_len=None, test_len=None, domain=None, noise_type=None, noise_kwargs=None, noise_seeds=None)

Bases: dival.datasets.dataset.Dataset

Dataset of pairs generated from a ground truth generator by applying a forward operator and noise.

NB: This dataset class does not allow for random access. If a fixed noise realization should be used for each sample, supporting random access would require restoring the same random generator state each time a given sample is accessed.

__init__(ground_truth_gen, forward_op, post_processor=None, train_len=None, validation_len=None, test_len=None, domain=None, noise_type=None, noise_kwargs=None, noise_seeds=None)

Parameters
  • ground_truth_gen (generator function) – Function returning a generator providing ground truth. Must accept a part parameter like Dataset.generator().

  • forward_op (odl operator) – Forward operator to apply on the ground truth.

  • post_processor (odl operator, optional) – Post-processor to apply on the result of the forward operator.

  • train_len (int, optional) – Number of training samples.

  • validation_len (int, optional) – Number of validation samples.

  • test_len (int, optional) – Number of test samples.

  • domain (odl space, optional) – Ground truth domain. If not specified, it is inferred from forward_op.

  • noise_type (str, optional) – Noise type. See NoiseOperator for the list of supported noise types.

  • noise_kwargs (dict, optional) – Keyword arguments passed to NoiseOperator.

  • noise_seeds (dict of int, optional) – Seeds to use for random noise generation. The part ('train', …) is the key to the dict. If a key is omitted or a value is None, no fixed seed is used for that part. By default, no fixed seeds are used.
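The role of noise_seeds, and why random access would be costly here, can be illustrated with a sketch: the random generator is re-created from the part's seed at the start of each pass, so a fixed seed reproduces the same noise realization every time, while the noise of sample k depends on all draws made for samples 0 to k-1. `pair_generator_sketch` is hypothetical; additive white Gaussian noise with a `noise_stddev` parameter is used only as an example noise model, post-processing is omitted, and lists stand in for odl elements.

```python
import random


def pair_generator_sketch(ground_truth_gen, forward_op, part='train',
                          noise_stddev=0.1, noise_seeds=None):
    """Yield (observation, ground truth) pairs with additive Gaussian noise.

    Hypothetical: forward_op maps a list of floats to a list of floats.
    """
    seeds = noise_seeds or {}
    # A fresh generator per pass; a fixed seed reproduces the same noise,
    # a missing seed (None) seeds from system randomness.
    rng = random.Random(seeds.get(part))
    for gt in ground_truth_gen(part=part):
        clean = forward_op(gt)
        observation = [v + rng.gauss(0.0, noise_stddev) for v in clean]
        yield observation, gt


def ground_truth_gen(part='train'):
    for i in range(3):
        yield [float(i), float(i + 1)]


def scale_by_two(x):
    return [2.0 * v for v in x]


first = list(pair_generator_sketch(ground_truth_gen, scale_by_two,
                                   noise_seeds={'train': 42}))
second = list(pair_generator_sketch(ground_truth_gen, scale_by_two,
                                    noise_seeds={'train': 42}))
print(first == second)  # True: the fixed seed reproduces the noise
```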

generator(part='train')

Yield data.

Yields pairs generated from the samples of ground_truth_gen by applying the forward operator and noise. Since this dataset does not support random access, get_sample() is not used.

Parameters

part ({'train', 'validation', 'test'}, optional) – Whether to yield train, validation or test data. Default is 'train'.

Yields

data (odl element or tuple of odl elements) – Sample of the dataset.

class dival.datasets.dataset.GroundTruthDataset(space=None)

Bases: dival.datasets.dataset.Dataset

Ground truth dataset base class.

__init__(space=None)

Parameters

space (odl.space.base_tensors.TensorSpace, optional) – The space of the samples. It is strongly recommended to set space in subclasses, as some functionality may depend on it.

create_pair_dataset(forward_op, post_processor=None, noise_type=None, noise_kwargs=None, noise_seeds=None)

Create a dataset of pairs of observation and ground truth from this ground truth dataset.

The parameters are a subset of those of ObservationGroundTruthPairDataset.__init__().