# -*- coding: utf-8 -*-
"""Provides `LoDoPaBDataset`.
Provides simple access to the `LoDoPaB-CT dataset
<https://zenodo.org/record/3384092>`_ documented in a
`Data Descriptor article
<https://www.nature.com/articles/s41597-021-00893-z>`_.
"""
import os
from math import ceil
from warnings import warn
from zipfile import ZipFile
import numpy as np
import h5py
from tqdm import tqdm
from odl import uniform_discr
import odl.tomo
from dival.config import CONFIG, set_config
from dival.datasets.dataset import Dataset
from dival.util.constants import MU_MAX
from dival.util.input import input_yes_no
from dival.util.zenodo_download import download_zenodo_record
try:
    DATA_PATH = CONFIG['lodopab_dataset']['data_path']
except Exception as e:
    raise RuntimeError(
        'Could not retrieve config value `lodopab_dataset/data_path`, '
        'maybe the configuration (e.g. in ~/.dival/config.json) is '
        'corrupt.') from e
# number of samples stored per HDF5 file
NUM_SAMPLES_PER_FILE = 128
# mean number of incident photons per detector pixel before attenuation
PHOTONS_PER_PIXEL = 4096
# replacement value for simulated photon counts of zero that was used when
# creating the stored (post-log) observations
ORIG_MIN_PHOTON_COUNT = 0.1
# default reconstruction space extent, corresponding to 26 cm x 26 cm
MIN_PT = [-0.13, -0.13]
MAX_PT = [0.13, 0.13]
LEN = {
'train': 35820,
'validation': 3522,
'test': 3553}
NUM_PATIENTS = {
'train': 632,
'validation': 60,
'test': 60}
PATIENT_ID_OFFSETS = {
'train': 0,
'validation': NUM_PATIENTS['train'],
'test': NUM_PATIENTS['train'] + NUM_PATIENTS['validation']}
def download_lodopab():
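    """Download the LoDoPaB-CT dataset from Zenodo and unpack it.

    Interactively asks for the storage path (the config value
    ``'lodopab_dataset/data_path'`` is updated accordingly), downloads all
    files of the Zenodo record and extracts the zip archives.

    Returns
    -------
    success : bool
        Whether download and extraction were successful.
    """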
global DATA_PATH
print('Before downloading, please make sure to have enough free disk '
'space (~150GB). After unpacking, 114.7GB will be used.')
print("path to store LoDoPaB-CT dataset (default '{}'):".format(DATA_PATH))
inp = input()
if inp:
DATA_PATH = inp
set_config('lodopab_dataset/data_path', DATA_PATH)
os.makedirs(DATA_PATH, exist_ok=True)
ZENODO_RECORD_ID = '3384092'
success = download_zenodo_record(ZENODO_RECORD_ID, DATA_PATH)
print('download of LoDoPaB-CT dataset {}'.format('successful' if success
else 'failed'))
if not success:
return False
file_list = ['observation_train.zip', 'ground_truth_train.zip',
'observation_validation.zip', 'ground_truth_validation.zip',
'observation_test.zip', 'ground_truth_test.zip']
print('unzipping zip files, this can take several minutes', flush=True)
for file in tqdm(file_list, desc='unzip'):
filename = os.path.join(DATA_PATH, file)
with ZipFile(filename, 'r') as f:
f.extractall(DATA_PATH)
os.remove(filename)
return True
class LoDoPaBDataset(Dataset):
"""
The LoDoPaB-CT dataset, which is documented in the Data Descriptor article
`<https://www.nature.com/articles/s41597-021-00893-z>`_ and hosted on
`<https://zenodo.org/record/3384092>`_.
It is a simulated low dose CT dataset based on real reconstructions from
the `LIDC-IDRI
<https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI>`_ dataset.
The dataset contains 42895 pairs of images and projection data.
For simulation, a ray transform with parallel beam geometry using 1000
angles and 513 detector pixels is used. Poisson noise corresponding to 4096
incident photons per pixel before attenuation is applied to the projection
data. The images have a size of 362x362 px.
An ODL ray transform that corresponds to the noiseless forward operator can
be obtained via the `get_ray_trafo` method of this dataset.
Additionally, the :attr:`ray_trafo` attribute holds a ray transform
instance, which is created during :meth:`__init__`.
    *Note:* By default, the ``'astra_cuda'`` implementation backend is used,
    which requires both astra and a CUDA-enabled GPU to be available.
You can choose a different backend by passing ``impl='skimage'`` or
``impl='astra_cpu'``.
Further functionalities:
* converting the stored post-log observations to pre-log observations
on the fly (cf. `observation_model` parameter of :meth:`__init__`)
* sorting by patient ids (cf. ``sorted_by_patient`` parameter of
:meth:`__init__`)
* changing the zero photon count replacement value of ``0.1`` used for
pre-log observations (cf. ``min_photon_count`` parameter of
:meth:`__init__`)
Attributes
----------
space
``(space[0], space[1])``, where
``space[0]``
``odl.uniform_discr([0., -0.1838], [3.1416, 0.1838],
(1000, 513), dtype='float32')``
``space[1]``
            ``odl.uniform_discr(min_pt, max_pt, (362, 362),
            dtype='float32')``, with `min_pt` and `max_pt` parameters
passed to :meth:`__init__`
    shape
        ``((1000, 513), (362, 362))``
train_len
``35820``
validation_len
``3522``
test_len
``3553``
random_access
``True``
num_elements_per_sample
``2``
ray_trafo : :class:`odl.tomo.RayTransform`
Ray transform corresponding to the noiseless forward operator.
sorted_by_patient : bool
Whether the samples are sorted by patient id.
Default: ``False``.
rel_patient_ids : (dict of array) or `None`
Relative patient ids of the samples in the original non-sorted order
for each part, as returned by :meth:`LoDoPaBDataset.get_patient_ids`.
`None`, if the csv files are not found.
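
    Examples
    --------
    A minimal usage sketch (assumes the dataset files to be available under
    the configured data path; ``impl='skimage'`` avoids the need for astra
    and a CUDA-enabled GPU):

    >>> dataset = LoDoPaBDataset(impl='skimage')
    >>> observation, ground_truth = dataset.get_sample(0)
    >>> ray_trafo = dataset.get_ray_trafo(impl='skimage')
    >>> noiseless_observation = ray_trafo(ground_truth)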
"""
def __init__(self, min_pt=None, max_pt=None, observation_model='post-log',
min_photon_count=None, sorted_by_patient=False,
impl='astra_cuda'):
"""
Parameters
----------
min_pt : [float, float], optional
Minimum values of the lp space. Default: ``[-0.13, -0.13]``.
max_pt : [float, float], optional
Maximum values of the lp space. Default: ``[0.13, 0.13]``.
observation_model : {'post-log', 'pre-log'}, optional
The observation model to use.
The default is ``'post-log'``.
``'post-log'``
Observations are linearly related to the normalized ground
truth via the ray transform, ``obs = ray_trafo(gt) + noise``.
Note that the scaling of the observations matches the
normalized ground truth, i.e., they are divided by the linear
attenuation of 3071 HU.
``'pre-log'``
Observations are non-linearly related to the ground truth, as
given by the Beer-Lambert law.
The model is
``obs = exp(-ray_trafo(gt * MU(3071 HU))) + noise``,
where `MU(3071 HU)` is the factor, by which the ground truth
was normalized.
        min_photon_count : float, optional
            Replacement value for simulated photon counts of zero.
            If ``observation_model == 'post-log'``, a value greater than
            zero is required in order to avoid undefined values.
            The default is ``0.1``, for both the ``'post-log'`` and the
            ``'pre-log'`` model.
sorted_by_patient : bool, optional
Whether to sort the samples by patient id.
Useful to resplit the dataset.
See also :meth:`get_indices_for_patient`.
            Note that in any case the slices of each patient are ordered
            randomly with respect to their z-location.
Default: ``False``.
impl : {``'skimage'``, ``'astra_cpu'``, ``'astra_cuda'``},\
optional
Implementation passed to :class:`odl.tomo.RayTransform` to
construct :attr:`ray_trafo`.
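
        Examples
        --------
        A sketch of constructing a pre-log variant of the dataset sorted by
        patient ids (assumes the dataset files and the patient id csv files
        to be available):

        >>> dataset = LoDoPaBDataset(observation_model='pre-log',
        ...                          sorted_by_patient=True,
        ...                          impl='skimage')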
"""
global DATA_PATH
NUM_ANGLES = 1000
NUM_DET_PIXELS = 513
self.shape = ((NUM_ANGLES, NUM_DET_PIXELS), (362, 362))
self.num_elements_per_sample = 2
if min_pt is None:
min_pt = MIN_PT
if max_pt is None:
max_pt = MAX_PT
domain = uniform_discr(min_pt, max_pt, self.shape[1], dtype=np.float32)
if observation_model == 'post-log':
self.post_log = True
elif observation_model == 'pre-log':
self.post_log = False
else:
raise ValueError("`observation_model` must be 'post-log' or "
"'pre-log', not '{}'".format(observation_model))
if min_photon_count is None or min_photon_count <= 1.:
self.min_photon_count = min_photon_count
else:
self.min_photon_count = 1.
warn('`min_photon_count` changed from {} to 1.'.format(
min_photon_count))
self.sorted_by_patient = sorted_by_patient
self.train_len = LEN['train']
self.validation_len = LEN['validation']
self.test_len = LEN['test']
self.random_access = True
while not LoDoPaBDataset.check_for_lodopab():
print('The LoDoPaB-CT dataset could not be found under the '
"configured path '{}'.".format(
CONFIG['lodopab_dataset']['data_path']))
print('Do you want to download it now? (y: download, n: input '
'other path)')
download = input_yes_no()
if download:
success = download_lodopab()
if not success:
raise RuntimeError('lodopab dataset not available, '
'download failed')
else:
print('Path to LoDoPaB dataset:')
DATA_PATH = input()
set_config('lodopab_dataset/data_path', DATA_PATH)
self.rel_patient_ids = None
try:
self.rel_patient_ids = LoDoPaBDataset.get_patient_ids()
except OSError as e:
if self.sorted_by_patient:
raise RuntimeError(
'Can not load patient ids, required for sorting. '
'OSError: {}'.format(e))
warn(
'Can not load patient ids (OSError: {}). '
'Therefore sorting is not possible, so please keep the '
'attribute `sorted_by_patient = False` for the LoDoPaBDataset.'
.format(e))
if self.rel_patient_ids is not None:
self._idx_sorted_by_patient = (
LoDoPaBDataset.get_idx_sorted_by_patient(
self.rel_patient_ids))
self.geometry = odl.tomo.parallel_beam_geometry(
domain, num_angles=NUM_ANGLES, det_shape=(NUM_DET_PIXELS,))
range_ = uniform_discr(self.geometry.partition.min_pt,
self.geometry.partition.max_pt,
self.shape[0], dtype=np.float32)
super().__init__(space=(range_, domain))
self.ray_trafo = self.get_ray_trafo(impl=impl)
def __get_observation_trafo(self, num_samples=1):
if (self.min_photon_count is None or
self.min_photon_count == ORIG_MIN_PHOTON_COUNT):
            if self.post_log:
                def observation_trafo(obs):
                    # stored observations are post-log already; nothing to do
                    pass
            else:
                def observation_trafo(obs):
                    # undo the post-log transform: rescale by the
                    # normalization factor MU_MAX and invert the negative log
                    obs *= MU_MAX
                    np.exp(-obs, out=obs)
else:
shape = (self.shape[0] if num_samples == 1 else
(num_samples,) + self.shape[0])
            mask = np.empty(shape, dtype=bool)
thres0 = 0.5 * (
-np.log(ORIG_MIN_PHOTON_COUNT/PHOTONS_PER_PIXEL)
- np.log(1/PHOTONS_PER_PIXEL)) / MU_MAX
if self.post_log:
def observation_trafo(obs):
np.greater_equal(obs, thres0, out=mask)
obs[mask] = -np.log(self.min_photon_count
/ PHOTONS_PER_PIXEL) / MU_MAX
else:
def observation_trafo(obs):
np.greater_equal(obs, thres0, out=mask)
obs *= MU_MAX
np.exp(-obs, out=obs)
obs[mask] = self.min_photon_count/PHOTONS_PER_PIXEL
return observation_trafo
def generator(self, part='train'):
"""Yield pairs of low dose observations and (virtual) ground truth.
Parameters
----------
part : {``'train'``, ``'validation'``, ``'test'``}, optional
The data part. Default is ``'train'``.
Yields
------
(observation, ground_truth)
`observation` : odl element with shape ``(1000, 513)``
The values depend on the
`observation_model` and `min_photon_count` parameters that were
passed to :meth:`__init__`.
`ground_truth` : odl element with shape ``(362, 362)``
The values lie in the range ``[0., 1.]``.
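
        Examples
        --------
        A sketch of iterating over the validation pairs (assumes the
        dataset files to be available):

        >>> dataset = LoDoPaBDataset(impl='skimage')
        >>> for observation, ground_truth in dataset.generator('validation'):
        ...     pass  # e.g. evaluate a reconstructor on the pair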
"""
if self.sorted_by_patient:
# fall back to default implementation
yield from super().generator(part=part)
return
num_files = ceil(self.get_len(part) / NUM_SAMPLES_PER_FILE)
observation_trafo = self.__get_observation_trafo()
for i in range(num_files):
with h5py.File(
os.path.join(DATA_PATH, 'ground_truth_{}_{:03d}.hdf5'
.format(part, i)), 'r') as file:
ground_truth_data = file['data'][:]
with h5py.File(
os.path.join(DATA_PATH, 'observation_{}_{:03d}.hdf5'
.format(part, i)), 'r') as file:
observation_data = file['data'][:]
for gt_arr, obs_arr in zip(ground_truth_data, observation_data):
ground_truth = self.space[1].element(gt_arr)
observation = self.space[0].element(obs_arr)
observation_trafo(observation)
yield (observation, ground_truth)
def get_ray_trafo(self, **kwargs):
"""
Return the ray transform that is a noiseless version of the forward
operator.
Parameters
----------
impl : {``'skimage'``, ``'astra_cpu'``, ``'astra_cuda'``}, optional
The backend implementation passed to
:class:`odl.tomo.RayTransform`.
Returns
-------
ray_trafo : odl operator
The ray transform that corresponds to the noiseless map from
362 x 362 images to the ``-log`` of their projections (sinograms).
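
        Examples
        --------
        A sketch of applying the noiseless forward operator to a ground
        truth sample (assumes the dataset files to be available):

        >>> dataset = LoDoPaBDataset(impl='skimage')
        >>> ray_trafo = dataset.get_ray_trafo(impl='skimage')
        >>> _, ground_truth = dataset.get_sample(0, out=(False, True))
        >>> noiseless_observation = ray_trafo(ground_truth)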
"""
return odl.tomo.RayTransform(self.space[1], self.geometry, **kwargs)
def get_sample(self, index, part='train', out=None):
"""
Get single sample of the dataset.
        Returns a pair of (virtual) ground truth and its low dose
        observation, either part of which can be omitted via `out`.
Parameters
----------
index : int
The index into the dataset part.
part : {``'train'``, ``'validation'``, ``'test'``}, optional
The data part. Default is ``'train'``.
out : tuple of array-likes or bools, optional
``out==(out_observation, out_ground_truth)``
out_observation : array-like or bool
Shape ``(1000, 513)``.
If an odl element or array is passed, the observation is
written to it.
If ``True``, a new odl element holding the observation is
created (the default).
If ``False``, no observation is returned.
out_ground_truth : array-like or bool
Shape ``(362, 362)``.
If an odl element or array is passed, the ground truth is
written to it.
If ``True``, a new odl element holding the ground truth is
created (the default).
If ``False``, no ground truth is returned.
Returns
-------
``(observation, ground_truth)``
observation : odl element or :class:`np.ndarray` or `None`
Depending on the value of ``out_observation`` (see parameter
`out`), a newly created odl element, ``out_observation`` or
`None` is returned.
The observation values depend on the `observation_model` and
`min_photon_count` parameters that were given to the
constructor.
ground_truth : odl element or :class:`np.ndarray` or `None`
Depending on the value of ``out_ground_truth`` (see parameter
`out`), a newly created odl element, ``out_ground_truth`` or
`None` is returned.
The values lie in the range ``[0., 1.]``.
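
        Examples
        --------
        A sketch of reading only the ground truth of a test sample into a
        preallocated array (the array name is illustrative; assumes the
        dataset files to be available):

        >>> import numpy as np
        >>> dataset = LoDoPaBDataset(impl='skimage')
        >>> gt_buffer = np.empty((362, 362), dtype=np.float32)
        >>> _, _ = dataset.get_sample(3, part='test', out=(False, gt_buffer))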
"""
len_part = self.get_len(part)
if index >= len_part or index < -len_part:
raise IndexError("index {} out of bounds for part '{}' ({:d})"
.format(index, part, len_part))
if index < 0:
index += len_part
if out is None:
out = (True, True)
(out_observation, out_ground_truth) = out
if self.sorted_by_patient:
index = self._idx_sorted_by_patient[part][index]
file_index = index // NUM_SAMPLES_PER_FILE
index_in_file = index % NUM_SAMPLES_PER_FILE
if isinstance(out_observation, bool):
obs = self.space[0].zero() if out_observation else None
else:
obs = out_observation
if isinstance(out_ground_truth, bool):
gt = self.space[1].zero() if out_ground_truth else None
else:
gt = out_ground_truth
if obs is not None:
with h5py.File(
os.path.join(DATA_PATH,
'observation_{}_{:03d}.hdf5'
.format(part, file_index)), 'r') as file:
file['data'].read_direct(np.asarray(obs)[np.newaxis],
np.s_[index_in_file:index_in_file+1],
np.s_[0:1])
observation_trafo = self.__get_observation_trafo()
observation_trafo(obs)
if gt is not None:
with h5py.File(
os.path.join(DATA_PATH,
'ground_truth_{}_{:03d}.hdf5'
.format(part, file_index)), 'r') as file:
file['data'].read_direct(np.asarray(gt)[np.newaxis],
np.s_[index_in_file:index_in_file+1],
np.s_[0:1])
return (obs, gt)
def get_samples(self, key, part='train', out=None):
"""
Get slice of the dataset.
        Returns a pair of (virtual) ground truth data and its low dose
        observation data, either part of which can be omitted via `out`.
Parameters
----------
key : slice or range
The indices into the dataset part.
part : {``'train'``, ``'validation'``, ``'test'``}, optional
The data part. Default is ``'train'``.
out : tuple of arrays or bools, optional
``out==(out_observation, out_ground_truth)``
out_observation : :class:`np.ndarray` or bool
If an array is passed, the observation data is written to it.
If ``True``, a new array holding the observation data is
created (the default).
If ``False``, no observation data is returned.
out_ground_truth : :class:`np.ndarray` or bool
If an array is passed, the ground truth data is written to it.
If ``True``, a new array holding the ground truth data is
created (the default).
If ``False``, no ground truth data is returned.
Returns
-------
``(observation, ground_truth)``
observation : :class:`np.ndarray` or `None`
Shape ``(samples, 1000, 513)``.
Depending on the value of ``out_observation`` (see parameter
`out`), a newly created array, ``out_observation`` or `None`
is returned.
The observation values depend on the `observation_model` and
`min_photon_count` parameters that were given to the
constructor.
ground_truth : :class:`np.ndarray` or `None`
Shape ``(samples, 362, 362)``.
Depending on the value of ``out_ground_truth`` (see parameter
`out`), a newly created array, ``out_ground_truth`` or `None`
is returned.
The values lie in the range ``[0., 1.]``.
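
        Examples
        --------
        A sketch of loading the first 16 validation pairs as numpy arrays
        (assumes the dataset files to be available):

        >>> dataset = LoDoPaBDataset(impl='skimage')
        >>> obs_arr, gt_arr = dataset.get_samples(range(16),
        ...                                       part='validation')
        >>> obs_arr.shape, gt_arr.shape
        ((16, 1000, 513), (16, 362, 362))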
"""
if self.sorted_by_patient:
# fall back to default implementation
return super().get_samples(key, part=part, out=out)
len_part = self.get_len(part)
if isinstance(key, slice):
key_start = (0 if key.start is None else
(key.start if key.start >= 0 else
max(0, len_part+key.start)))
key_stop = (len_part if key.stop is None else
(key.stop if key.stop >= 0 else
max(0, len_part+key.stop)))
range_ = range(key_start, key_stop, key.step or 1)
elif isinstance(key, range):
range_ = key
else:
raise TypeError('`key` expected to have type `slice` or `range`')
        if range_.step < 0:
            raise ValueError('key {} invalid, negative steps are not '
                             'implemented yet'.format(key))
        if len(range_) == 0:
            raise ValueError(
                "key {} selects no samples of part '{}'".format(key, part))
        if range_[-1] >= len_part:
            raise IndexError("key {} out of bounds for part '{}' ({:d})"
                             .format(key, part, len_part))
range_files = range(range_[0] // NUM_SAMPLES_PER_FILE,
range_[-1] // NUM_SAMPLES_PER_FILE + 1)
if out is None:
out = (True, True)
(out_observation, out_ground_truth) = out
# compute slice objects
slices_files = []
slices_data = []
data_count = 0
for i in range_files:
if i == range_files.start:
start = range_.start % NUM_SAMPLES_PER_FILE
else:
start = (range_.start - i*NUM_SAMPLES_PER_FILE) % range_.step
if i == range_files[-1]:
stop = range_[-1] % NUM_SAMPLES_PER_FILE + 1
else:
__next_start = ((range_.start - (i+1)*NUM_SAMPLES_PER_FILE)
% range_.step)
stop = (__next_start - range_.step) % NUM_SAMPLES_PER_FILE + 1
s = slice(start, stop, range_.step)
slices_files.append(s)
len_slice = ceil((s.stop-s.start) / s.step)
slices_data.append(slice(data_count, data_count+len_slice))
data_count += len_slice
# read data
if isinstance(out_observation, bool):
obs_arr = (np.empty((len(range_),) + self.shape[0],
dtype=np.float32) if out_observation else None)
else:
obs_arr = out_observation
if isinstance(out_ground_truth, bool):
gt_arr = (np.empty((len(range_),) + self.shape[1],
dtype=np.float32) if out_ground_truth else None)
else:
gt_arr = out_ground_truth
if obs_arr is not None:
for i, slc_f, slc_d in zip(range_files, slices_files, slices_data):
with h5py.File(
os.path.join(DATA_PATH,
'observation_{}_{:03d}.hdf5'
.format(part, i)), 'r') as file:
file['data'].read_direct(obs_arr, slc_f, slc_d)
observation_trafo = self.__get_observation_trafo(
num_samples=len(obs_arr))
observation_trafo(obs_arr)
if gt_arr is not None:
for i, slc_f, slc_d in zip(range_files, slices_files, slices_data):
with h5py.File(
os.path.join(DATA_PATH,
'ground_truth_{}_{:03d}.hdf5'
.format(part, i)), 'r') as file:
file['data'].read_direct(gt_arr, slc_f, slc_d)
return (obs_arr, gt_arr)
def get_indices_for_patient(self, rel_patient_id, part='train'):
"""
Return the indices of the samples from one patient.
If ``self.sorted_by_patient`` is ``True``, the indices will be
subsequent.
Parameters
----------
rel_patient_id : int
Patient id, relative to the part.
        part : {``'train'``, ``'validation'``, ``'test'``}, optional
            The data part. Default is ``'train'``.
Returns
-------
indices : array
The indices of the samples from the patient.
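
        Examples
        --------
        A sketch of collecting all samples of the first validation patient
        (assumes the dataset files and the patient id csv files to be
        available):

        >>> dataset = LoDoPaBDataset(impl='skimage')
        >>> indices = dataset.get_indices_for_patient(0, part='validation')
        >>> samples = [dataset.get_sample(int(i), part='validation')
        ...            for i in indices]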
"""
if self.sorted_by_patient:
num_samples_by_patient = np.bincount(self.rel_patient_ids[part])
first_sample = np.sum(num_samples_by_patient[:rel_patient_id])
indices = np.array(range(
first_sample,
first_sample + num_samples_by_patient[rel_patient_id]))
else:
indices = np.nonzero(
self.rel_patient_ids[part] == rel_patient_id)[0]
return indices
@staticmethod
def check_for_lodopab():
"""Fast check whether first and last file of each dataset part exist
under the configured data path.
Returns
-------
exists : bool
Whether LoDoPaB seems to exist.
"""
for part in ['train', 'validation', 'test']:
first_file = os.path.join(
DATA_PATH, 'observation_{}_000.hdf5'.format(part))
last_file = os.path.join(
DATA_PATH, 'observation_{}_{:03d}.hdf5'.format(
part, ceil(LEN[part] / NUM_SAMPLES_PER_FILE) - 1))
if not (os.path.exists(first_file) and os.path.exists(last_file)):
return False
return True
@staticmethod
def get_num_patients(part='train'):
"""
Return the number of patients in a dataset part.
Parameters
----------
        part : {``'train'``, ``'validation'``, ``'test'``}, optional
            Whether to return the number of train, validation or test
            patients. Default is ``'train'``.

        Returns
        -------
        num_patients : int
            Number of patients in the dataset part.
        """
return NUM_PATIENTS[part]
@staticmethod
def _abs_to_rel_patient_id(abs_patient_id, part):
return abs_patient_id - PATIENT_ID_OFFSETS[part]
@staticmethod
def _rel_to_abs_patient_id(rel_patient_id, part):
return rel_patient_id + PATIENT_ID_OFFSETS[part]
@staticmethod
def get_patient_ids(relative=True):
"""
Return the (relative) patient id for all samples of all dataset parts.
Parameters
----------
relative : bool, optional
Whether to use ids relative to the dataset part.
The csv files store absolute indices, where
"train_ids < validation_ids < test_ids".
If ``False``, these absolute indices are returned.
If ``True``, the smallest absolute id of the part is subtracted,
giving zero-based (relative) patient ids.
Default: ``True``
Returns
-------
ids : dict of array
For each part: an array with the (relative) patient ids for all
samples (length: number of samples in the corresponding part).
Raises
------
OSError
An `OSError` is raised if one of the csv files containing the
patient ids is missing in the configured data path.
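
        Examples
        --------
        A sketch of counting the samples per patient in the training part
        (assumes the patient id csv files to be available):

        >>> import numpy as np
        >>> ids = LoDoPaBDataset.get_patient_ids()
        >>> num_samples_per_patient = np.bincount(ids['train'])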
"""
ids = {}
for part in ['train', 'validation', 'test']:
ids[part] = np.loadtxt(
os.path.join(DATA_PATH,
'patient_ids_rand_{}.csv'.format(part)),
dtype=int)
if relative:
ids[part] = LoDoPaBDataset._abs_to_rel_patient_id(ids[part],
part)
return ids
@staticmethod
def get_idx_sorted_by_patient(ids=None):
"""
Return indices that allow access to each dataset part in patient id
order.
        *Note:* in most cases this method should not be called directly.
        Rather, specify ``sorted_by_patient=True`` to the constructor if
        applicable.
A plausible use case of this method, however, is to access existing
cache files that were created with ``sorted_by_patient=False``.
In this case, the dataset should be constructed with
``sorted_by_patient=False``, wrapped by a :class:`CachedDataset`
and then reordered with :class:`ReorderedDataset` using the indices
returned by this method.
Parameters
----------
ids : dict of array-like, optional
Patient ids as returned by :meth:`get_patient_ids`. It is not
relevant to this function whether they are relative.
Returns
-------
idx : dict of array
Indices that allow access to each dataset part in patient id order.
Each array value is an index into the samples in original order
(as stored in the HDF5 files).
I.e.: By iterating the samples with index ``idx[part][i]`` for
``i = 0, 1, 2, ...`` one first obtains all samples from one
patient, then continues with the samples of the second patient, and
so on.
Raises
------
OSError
An `OSError` is raised if ``ids is None`` and one of the csv files
containing the patient ids is missing in the configured data path.
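
        Examples
        --------
        A sketch of the cache use case described above (assumes the patient
        id csv files to be available; ``cached_dataset`` stands for a
        :class:`CachedDataset` wrapping this dataset constructed with
        ``sorted_by_patient=False``, and the import paths are assumed to be
        those of :mod:`dival.datasets`):

        >>> from dival.datasets.reordered_dataset import ReorderedDataset
        >>> idx = LoDoPaBDataset.get_idx_sorted_by_patient()
        >>> sorted_dataset = ReorderedDataset(cached_dataset, idx)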
"""
if ids is None:
ids = LoDoPaBDataset.get_patient_ids()
idx = {}
for part in ['train', 'validation', 'test']:
idx[part] = np.argsort(ids[part], kind='stable')
return idx