dival.datasets.dataset module
Provides the dataset base classes.
- class dival.datasets.dataset.Dataset(space=None)[source]
Bases:
object
Dataset base class.
Subclasses must either implement generator() or provide random access by implementing get_sample() and get_samples() (which then should be indicated by setting the attribute random_access = True).
- space
The spaces of the elements of samples as a tuple. If only one element per sample is provided, this attribute is the space of the element (i.e., no tuple). It is strongly recommended to set this attribute in subclasses, as some functionality may depend on it.
- Type:
[tuple of ]
odl.space.base_tensors.TensorSpace
or None
- shape
The shapes of the elements of samples as a tuple of tuple of int. If only one element per sample is provided, this attribute is the shape of the element (i.e., not a tuple of tuple of int, but a tuple of int).
- Type:
[tuple of ] tuple of int, optional
- train_len
Number of training samples.
- Type:
int, optional
- validation_len
Number of validation samples.
- Type:
int, optional
- test_len
Number of test samples.
- Type:
int, optional
- random_access
Whether the dataset supports random access via self.get_sample and self.get_samples. Setting this attribute is the preferred way for subclasses to indicate whether they support random access.
- Type:
bool, optional
- num_elements_per_sample
Number of elements per sample. E.g. 1 for a ground truth dataset or 2 for a dataset of pairs of observation and ground truth.
- Type:
int, optional
- standard_dataset_name
Datasets returned by get_standard_dataset() have this attribute, giving the name of the standard dataset.
- Type:
str, optional
- __init__(space=None)[source]
The attributes that potentially should be set by the subclass are: space (can also be set by argument), shape, train_len, validation_len, test_len, random_access and num_elements_per_sample.
- Parameters:
space ([tuple of ] odl.space.base_tensors.TensorSpace, optional) – The spaces of the elements of samples as a tuple. If only one element per sample is provided, this argument is the space of the element (i.e., no tuple). It is strongly recommended to set space in subclasses, as some functionality may depend on it.
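The subclassing contract described above can be sketched as follows. This is a hedged illustration only: a stand-in base class replaces the real dival.datasets.dataset.Dataset, and plain Python floats and lists stand in for odl elements and spaces.

```python
# Sketch of the subclassing pattern: a random-access dataset sets the
# length attributes and random_access = True, and implements
# get_sample()/get_samples(). The base class here is a simplified
# stand-in for dival.datasets.dataset.Dataset, not the real one.

class DatasetBase:
    """Stand-in for dival.datasets.dataset.Dataset."""

    def __init__(self, space=None):
        self.space = space

    def generator(self, part='train'):
        # Default implementation: fall back to random access if available.
        for i in range(self.get_len(part)):
            yield self.get_sample(i, part=part)

    def get_len(self, part='train'):
        return {'train': self.train_len,
                'validation': self.validation_len,
                'test': self.test_len}[part]


class RangeDataset(DatasetBase):
    """Toy random-access dataset with one element per sample."""

    def __init__(self):
        super().__init__(space=None)
        self.train_len = 8
        self.validation_len = 2
        self.test_len = 2
        self.num_elements_per_sample = 1
        self.random_access = True  # indicates get_sample/get_samples work

    def get_sample(self, index, part='train', out=None):
        return float(index)  # stand-in for an odl element

    def get_samples(self, key, part='train', out=None):
        return [self.get_sample(i, part)
                for i in range(*key.indices(self.get_len(part)))]
```

With this in place, iteration via generator() works automatically for any part, since the default implementation delegates to get_sample().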
- generator(part='train')[source]
Yield data.
The default implementation calls get_sample() if the dataset implements it (i.e., supports random access).
- Parameters:
part ({'train', 'validation', 'test'}, optional) – Whether to yield train, validation or test data. Default is 'train'.
- Yields:
data (odl element or tuple of odl elements) – Sample of the dataset.
- get_len(part='train')[source]
Return the number of elements the generator will yield.
- Parameters:
part ({'train', 'validation', 'test'}, optional) – Whether to return the number of train, validation or test elements. Default is 'train'.
- get_shape()[source]
Return the shape of each element.
Returns shape if it is set. Otherwise, it is inferred from space (which is strongly recommended to be set in every subclass). If space is not set either, a NotImplementedError is raised.
- Returns:
shape
- Return type:
[tuple of ] tuple
- get_num_elements_per_sample()[source]
Return number of elements per sample.
Returns num_elements_per_sample if it is set. Otherwise, it is inferred from space (which is strongly recommended to be set in every subclass). If space is not set either, a NotImplementedError is raised.
- Returns:
num_elements_per_sample
- Return type:
int
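The inference from space described above amounts to counting the spaces in the tuple. A minimal sketch, assuming the attribute conventions listed earlier (the real implementation may differ in details):

```python
# Hedged sketch: infer the number of elements per sample from the space
# attribute. A tuple of spaces means one element per entry; a single
# (non-tuple) space means one element per sample; no space at all means
# the value cannot be inferred.

def infer_num_elements_per_sample(space):
    if space is None:
        raise NotImplementedError('space is not set')
    return len(space) if isinstance(space, tuple) else 1
```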
- get_data_pairs(part='train', n=None)[source]
Return the first samples from a data part as a DataPairs object.
Only supports datasets with two elements per sample.
- Parameters:
part ({'train', 'validation', 'test'}, optional) – The data part. Default is 'train'.
n (int, optional) – Number of pairs (from the beginning). If None, all available data is used (the default).
- get_data_pairs_per_index(part='train', index=None)[source]
Return specific samples from a data part as a DataPairs object.
Only supports datasets with two elements per sample.
For datasets not supporting random access, samples are extracted from generator(), which can be computationally expensive.
- Parameters:
part ({'train', 'validation', 'test'}, optional) – The data part. Default is 'train'.
index (int or list of int, optional) – Indices of the samples in the data part. Default is [0].
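The expensive fallback path mentioned above (extracting indexed samples from a sequential generator) can be sketched as follows; the function name is made up for illustration:

```python
# Hedged sketch: pull samples at specific indices out of a sequential
# generator. The generator must be consumed up to the largest requested
# index, which is why this can be computationally expensive for datasets
# without random access.

def samples_by_index_from_generator(gen, index):
    wanted = set(index)
    found = {}
    for i, sample in enumerate(gen):
        if i in wanted:
            found[i] = sample
        if len(found) == len(wanted):
            break  # all requested indices collected; stop consuming
    return [found[i] for i in index]
```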
- create_torch_dataset(part='train', reshape=None, transform=None)[source]
Create a torch dataset wrapper for one part of this dataset.
If supports_random_access() returns False, a subclass of torch.utils.data.IterableDataset is returned that fetches samples via generator(). Note: when using torch's DataLoader with multiple workers, you might want to individually configure the datasets for each worker; see the PyTorch docs on IterableDataset. For this purpose it can be useful to modify the wrapped dival dataset in worker_init_fn(), where it can be accessed via torch.utils.data.get_worker_info().dataset.dataset.
If supports_random_access() returns True, a subclass of torch.utils.data.Dataset is returned that retrieves samples using get_sample().
- Parameters:
.- Parameters:
part ({'train', 'validation', 'test'}, optional) – The data part. Default is 'train'.
reshape (tuple of (tuple or None), optional) – Shapes to which the elements of each sample will be reshaped. If None is passed for an element, no reshape is applied.
transform (callable, optional) – Transform to be applied on each sample, useful for augmentation. Default: None, i.e. no transform.
- Returns:
dataset – The torch dataset wrapping this dataset. The wrapped dival dataset is assigned to the attribute dataset.dataset.
- Return type:
torch.utils.data.Dataset or torch.utils.data.IterableDataset
- create_keras_generator(part='train', batch_size=1, shuffle=True, reshape=None)[source]
Create a keras data generator wrapper for one part of this dataset.
If supports_random_access() returns False, a generator wrapping generator() is returned. In this case no shuffling is performed, regardless of the passed shuffle parameter. Also, parallel data loading (with multiple workers) is not applicable.
If supports_random_access() returns True, a tf.keras.utils.Sequence is returned, which is implemented using get_sample(). For datasets that support parallel calls to get_sample(), the returned data generator (sequence) can be used by multiple workers.
- Parameters:
part ({'train', 'validation', 'test'}, optional) – The data part. Default is 'train'.
batch_size (int, optional) – Batch size. Default is 1.
shuffle (bool, optional) – Whether to shuffle samples each epoch. This option has no effect if supports_random_access() returns False, since in that case samples are fetched directly from generator(). The default is True.
reshape (tuple of (tuple or None), optional) – Shapes to which the elements of each sample will be reshaped. If None is passed for an element, no reshape is applied.
- get_sample(index, part='train', out=None)[source]
Get single sample by index.
- Parameters:
index (int) – Index of the sample.
part ({'train', 'validation', 'test'}, optional) – The data part. Default is 'train'.
.out (array-like or tuple of (array-like or bool) or None) –
Array(s) (or e.g. odl element(s)) to which the sample is written. A tuple should be passed, if the dataset returns two or more arrays per sample (i.e. pairs, …). If a tuple element is a bool, it has the following meaning:
True
Create a new array and return it.
False
Do not return this array, i.e. None is returned.
- Returns:
sample – E.g. for a pair dataset: (array, None) if out=(True, False).
- Return type:
[tuple of ] (array-like or None)
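The out-tuple semantics described above can be sketched for a pair dataset as follows. This is an illustration only (the helper name is made up, and Python lists stand in for arrays or odl elements):

```python
# Hedged sketch of the out semantics for a two-element (pair) sample:
#   True       -> allocate a new "array" and return it
#   False      -> do not return this element (None is returned instead)
#   array-like -> write the element into the provided buffer in place

def resolve_out_pair(sample_pair, out):
    result = []
    for value, o in zip(sample_pair, out):
        if o is True:
            result.append(list(value))   # create a new array and return it
        elif o is False:
            result.append(None)          # this array is not returned
        else:
            o[:] = value                 # write into the given buffer
            result.append(o)
    return tuple(result)
```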
- get_samples(key, part='train', out=None)[source]
Get samples by slice or range.
The default implementation calls get_sample() if the dataset implements it.
- Parameters:
key (slice or range) – Indices of the samples.
part ({'train', 'validation', 'test'}, optional) – The data part. Default is 'train'.
.out (array-like or tuple of (array-like or bool) or None) –
Array(s) (or e.g. odl element(s)) to which the samples are written. The first dimension must match the number of samples requested. A tuple should be passed if the dataset returns two or more arrays per sample (i.e. pairs, …). If a tuple element is a bool, it has the following meaning:
True
Create a new array and return it.
False
Do not return this array, i.e. None is returned.
- Returns:
samples – If the dataset has multiple arrays per sample, a tuple holding arrays is returned. E.g. for a pair dataset: (array, None) if out=(True, False). The samples are stacked in the first (additional) dimension of each array.
- Return type:
[tuple of ] (array-like or None)
- supports_random_access()[source]
Whether random access seems to be supported.
If the object has the attribute self.random_access, its value is returned (this is the preferred way for subclasses to indicate whether they support random access). Otherwise, a simple duck-type check is performed which tries to get the first sample by random access.
- Returns:
supports – True if the dataset supports random access, otherwise False.
- Return type:
bool
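The check described above (attribute first, duck-type probe as fallback) can be sketched as a standalone function. This is an approximation of the documented behavior, not the library's actual code:

```python
# Hedged sketch: prefer an explicit random_access attribute; otherwise
# duck-type by attempting to fetch the first sample and treating failure
# as "no random access".

def supports_random_access(dataset):
    if hasattr(dataset, 'random_access'):
        return dataset.random_access
    try:
        dataset.get_sample(0)
    except (AttributeError, NotImplementedError):
        return False
    return True
```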
- class dival.datasets.dataset.ObservationGroundTruthPairDataset(ground_truth_gen, forward_op, post_processor=None, train_len=None, validation_len=None, test_len=None, domain=None, noise_type=None, noise_kwargs=None, noise_seeds=None)[source]
Bases:
Dataset
Dataset of pairs generated from a ground truth generator by applying a forward operator and noise.
NB: This dataset class does not allow for random access. Supporting random access would require restoring the same random generator state each time the same sample is accessed, if a fixed noise realization should be used for each sample.
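The pair-generation scheme can be sketched as below. This is a simplified stand-in: the identity-like forward operator and additive Gaussian noise are made-up examples, whereas the real class applies odl operators and dival's NoiseOperator; the per-part seed handling mirrors the noise_seeds parameter described in __init__().

```python
# Hedged sketch: generate (observation, ground_truth) pairs by applying a
# forward operator and noise to a ground truth generator. A fixed seed per
# data part gives a reproducible noise realization for that part; seed=None
# means no fixed realization.
import random

def pair_generator(ground_truth_gen, forward_op, part='train', noise_seeds=None):
    seed = (noise_seeds or {}).get(part)
    rng = random.Random(seed)  # seed=None -> unseeded RNG
    for gt in ground_truth_gen(part):
        observation = forward_op(gt) + rng.gauss(0.0, 0.1)
        yield observation, gt
```

Note how reproducing a fixed noise realization requires re-seeding the RNG for the whole part; serving sample i alone would require restoring the RNG state reached just before sample i, which is why random access is not supported.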
- __init__(ground_truth_gen, forward_op, post_processor=None, train_len=None, validation_len=None, test_len=None, domain=None, noise_type=None, noise_kwargs=None, noise_seeds=None)[source]
- Parameters:
ground_truth_gen (generator function) – Function returning a generator providing ground truth. Must accept a part parameter like Dataset.generator().
forward_op (odl operator) – Forward operator to apply on the ground truth.
post_processor (odl operator, optional) – Post-processor to apply on the result of the forward operator.
train_len (int, optional) – Number of training samples.
validation_len (int, optional) – Number of validation samples.
test_len (int, optional) – Number of test samples.
domain (odl space, optional) – Ground truth domain. If not specified, it is inferred from forward_op.
noise_type (str, optional) – Noise type. See NoiseOperator for the list of supported noise types.
noise_kwargs (dict, optional) – Keyword arguments passed to NoiseOperator.
noise_seeds (dict of int, optional) – Seeds to use for random noise generation. The part ('train', …) is the key to the dict. If a key is omitted or a value is None, no fixed seed is used for that part. By default, no fixed seeds are used.
- generator(part='train')[source]
Yield data.
The default implementation calls get_sample() if the dataset implements it (i.e., supports random access).
- Parameters:
part ({'train', 'validation', 'test'}, optional) – Whether to yield train, validation or test data. Default is 'train'.
- Yields:
data (odl element or tuple of odl elements) – Sample of the dataset.
- class dival.datasets.dataset.GroundTruthDataset(space=None)[source]
Bases:
Dataset
Ground truth dataset base class.
- __init__(space=None)[source]
- Parameters:
space (odl.space.base_tensors.TensorSpace, optional) – The space of the samples. It is strongly recommended to set space in subclasses, as some functionality may depend on it.
- create_pair_dataset(forward_op, post_processor=None, noise_type=None, noise_kwargs=None, noise_seeds=None)[source]
The parameters are a subset of those of ObservationGroundTruthPairDataset.__init__().