Source code for dival.datasets.reordered_dataset

from dival.datasets import Dataset


[docs]class ReorderedDataset(Dataset): """ Dataset that reorders the samples of another dataset by specified index arrays for each part. """
[docs] def __init__(self, dataset, idx): """ Parameters ---------- dataset : `Dataset` Dataset to take the samples from. Must support random access. idx : dict of array-like Indices into the original dataset for each part. Each array-like must have (at least) the same length as the part. """ assert dataset.supports_random_access() self.dataset = dataset self.idx = idx self.train_len = self.dataset.get_len('train') self.validation_len = self.dataset.get_len('validation') self.test_len = self.dataset.get_len('test') self.random_access = True self.num_elements_per_sample = ( self.dataset.get_num_elements_per_sample()) super().__init__(space=self.dataset.space) self.shape = self.dataset.get_shape()
# use default implementation of generator
[docs] def get_sample(self, index, part='train', out=None): sample = self.dataset.get_sample( self.idx[part][index], part=part, out=out) return sample
# use default implementation of get_samples, which calls get_sample # (seems close to optimal as the indices do not follow a simple pattern)