Source code for dival.datasets.angle_subset_dataset

import numpy as np
from odl import nonuniform_partition
from odl.tomo import Parallel2dGeometry, RayTransform

from dival.datasets.dataset import Dataset


[docs]def get_angle_subset_dataset(dataset, num_angles, **kwargs): """ Return a :class:`AngleSubsetDataset` with a reduced number of angles. The angles in the subset are equidistant if the original angles are. Parameters ---------- dataset : `Dataset` Basis CT dataset. The number of angles must be divisible by `num_angles`. num_angles : int Number of angles in the subset. Keyword arguments are passed to ``AngleSubsetDataset.__init__``. Raises ------ ValueError If the number of angles of `dataset` is not divisible by `num_angles`. Returns ------- angle_subset_dataset : :class:`AngleSubsetDataset` The dataset with the reduced number of angles. """ num_angles_total = dataset.get_shape()[0][0] angle_indices = range(0, num_angles_total, num_angles_total // num_angles) if len(angle_indices) != num_angles: raise ValueError( 'Number of angles {:d} is not divisible by requested number of ' 'angles {:d}'.format(num_angles_total, num_angles)) return AngleSubsetDataset(dataset, angle_indices, **kwargs)
[docs]class AngleSubsetDataset(Dataset): """ CT dataset that selects a subset of the angles of a basis CT dataset. """
[docs] def __init__(self, dataset, angle_indices, impl=None): """ Parameters ---------- dataset : `Dataset` Basis CT dataset. Requirements: - sample elements are ``(observation, ground_truth)`` - :meth:`get_ray_trafo` gives corresponding ray transform. angle_indices : array-like or slice Indices of the angles to use from the observations. impl : {``'skimage'``, ``'astra_cpu'``, ``'astra_cuda'``},\ optional Implementation passed to :class:`odl.tomo.RayTransform` to construct :attr:`ray_trafo`. """ self.dataset = dataset self.angle_indices = (angle_indices if isinstance(angle_indices, slice) else np.asarray(angle_indices)) self.train_len = self.dataset.get_len('train') self.validation_len = self.dataset.get_len('validation') self.test_len = self.dataset.get_len('test') self.random_access = self.dataset.supports_random_access() self.num_elements_per_sample = ( self.dataset.get_num_elements_per_sample()) orig_geometry = self.dataset.get_ray_trafo(impl=impl).geometry apart = nonuniform_partition( orig_geometry.angles[self.angle_indices]) self.geometry = Parallel2dGeometry( apart=apart, dpart=orig_geometry.det_partition) orig_shape = self.dataset.get_shape() self.shape = ((apart.shape[0], orig_shape[0][1]), orig_shape[1]) self.space = (None, self.dataset.space[1]) # preliminary, needed for # call to get_ray_trafo self.ray_trafo = self.get_ray_trafo(impl=impl) super().__init__(space=(self.ray_trafo.range, self.dataset.space[1]))
[docs] def get_ray_trafo(self, **kwargs): """ Return the ray transform that matches the subset of angles specified to the constructor via `angle_indices`. """ return RayTransform(self.space[1], self.geometry, **kwargs)
[docs] def generator(self, part='train'): for (obs, gt) in self.dataset.generator(part=part): yield (self.space[0].element(obs[self.angle_indices]), gt)
[docs] def get_sample(self, index, part='train', out=None): if out is None: out = (True, True) (out_obs, out_gt) = out out_basis = (out_obs is not False, out_gt) obs_basis, gt = self.dataset.get_sample(index, part=part, out=out_basis) if isinstance(out_obs, bool): obs = (self.space[0].element(obs_basis[self.angle_indices]) if out_obs else None) else: out_obs[:] = obs_basis[self.angle_indices] obs = out_obs return (obs, gt)
[docs] def get_samples(self, key, part='train', out=None): if out is None: out = (True, True) (out_obs, out_gt) = out out_basis = (out_obs is not False, out_gt) obs_arr_basis, gt_arr = self.dataset.get_samples(key, part=part, out=out_basis) if isinstance(out_obs, bool): obs_arr = obs_arr_basis[:, self.angle_indices] if out_obs else None else: out_obs[:] = obs_arr_basis[:, self.angle_indices] obs_arr = out_obs return (obs_arr, gt_arr)