"""
Provides utilities related to PyTorch.

The classes and functions
:class:`TorchRayTrafoParallel2DModule`,
:class:`TorchRayTrafoParallel2DAdjointModule`,
:func:`get_torch_ray_trafo_parallel_2d` and
:func:`get_torch_ray_trafo_parallel_2d_adjoint`
in this module rely on the
`tomosipo <https://github.com/ahendriksen/tomosipo>`_ library and on
experimental astra features that are available in version 1.9.9.dev4 and
require CUDA.
All of these requirements need to be fulfilled in order to instantiate or call
these classes and functions; otherwise an :class:`ImportError` (missing
tomosipo) or a :class:`RuntimeError` (astra not available or unable to use
CUDA) is raised.
"""
import numpy as np
import torch
try:
import tomosipo as ts
except ImportError:
TOMOSIPO_AVAILABLE = False
MISSING_TOMOSIPO_MESSAGE = (
'Missing optional dependency \'tomosipo\'. The latest development '
'version can be installed via '
'`pip install git+https://github.com/ahendriksen/tomosipo@develop`')
else:
TOMOSIPO_AVAILABLE = True
from tomosipo.odl import (
from_odl, parallel_2d_to_3d_geometry, discretized_space_2d_to_3d)
from tomosipo.torch_support import to_autograd
from odl.tomo.backends.astra_cuda import astra_cuda_bp_scaling_factor
try:
import astra
except ImportError:
ASTRA_AVAILABLE = False
else:
ASTRA_AVAILABLE = True
class RandomAccessTorchDataset(torch.utils.data.Dataset):
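    """
    Map-style :class:`torch.utils.data.Dataset` wrapping one part of a
    dataset that supports random access.

    The wrapped `dataset` is expected to provide the methods
    :meth:`get_len`, :meth:`get_num_elements_per_sample` and
    :meth:`get_sample`. Each requested sample is converted to
    :class:`torch.Tensor` objects, optionally reshaped according to
    `reshape` and passed through `transform` if specified.
    """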
def __init__(self, dataset, part, reshape=None,
transform=None):
self.dataset = dataset
self.part = part
self.reshape = reshape or (
(None,) * self.dataset.get_num_elements_per_sample())
self.transform = transform
def __len__(self):
return self.dataset.get_len(self.part)
def __getitem__(self, idx):
arrays = self.dataset.get_sample(idx, part=self.part)
mult_elem = isinstance(arrays, tuple)
if not mult_elem:
arrays = (arrays,)
tensors = []
for arr, s in zip(arrays, self.reshape):
t = torch.from_numpy(np.asarray(arr))
if s is not None:
t = t.view(*s)
tensors.append(t)
sample = tuple(tensors) if mult_elem else tensors[0]
if self.transform is not None:
sample = self.transform(sample)
return sample
class GeneratorTorchDataset(torch.utils.data.IterableDataset):
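    """
    Iterable-style :class:`torch.utils.data.IterableDataset` wrapping one
    part of a dataset that provides a :meth:`generator` method (as well as
    :meth:`get_len` and :meth:`get_num_elements_per_sample`).

    Each generated sample is converted to :class:`torch.Tensor` objects,
    optionally reshaped according to `reshape` and passed through
    `transform` if specified.
    """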
def __init__(self, dataset, part, reshape=None,
transform=None):
self.part = part
self.dataset = dataset
self.reshape = reshape or (
(None,) * dataset.get_num_elements_per_sample())
self.transform = transform
def __len__(self):
return self.dataset.get_len(self.part)
def __iter__(self):
return self.generate()
def generate(self):
for arrays in self.dataset.generator(self.part):
mult_elem = isinstance(arrays, tuple)
if not mult_elem:
arrays = (arrays,)
tensors = []
for arr, s in zip(arrays, self.reshape):
t = torch.from_numpy(np.asarray(arr))
if s is not None:
t = t.view(*s)
tensors.append(t)
sample = tuple(tensors) if mult_elem else tensors[0]
if self.transform is not None:
sample = self.transform(sample)
yield sample
class TorchRayTrafoParallel2DModule(torch.nn.Module):
"""
    Torch module applying a 2D parallel-beam ray transform via tomosipo,
    calling the direct forward projection routine of astra (available in
    version 1.9.9.dev4), which avoids copying between GPU and CPU.

    All 2D transforms of a batch are computed using a single 3D transform.
    To this end, the underlying tomosipo operator is recreated in
    :meth:`forward` every time the product of batch and channel dimensions
    of the current batch differs from that of the previous batch (or, for
    the first batch, from the value of `init_z_shape` passed to
    :meth:`__init__`).
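
    Examples
    --------
    A minimal usage sketch (assuming ``ray_trafo`` is a 2D parallel-beam
    :class:`odl.tomo.RayTransform` and a CUDA device is available; the
    variable names are illustrative):

    >>> torch_ray_trafo = TorchRayTrafoParallel2DModule(ray_trafo)
    >>> x = torch.zeros(4, 1, *ray_trafo.domain.shape, device='cuda')
    >>> y = torch_ray_trafo(x)  # shape: (4, 1) + ray_trafo.range.shape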
"""
def __init__(self, ray_trafo, init_z_shape=1):
"""
Parameters
----------
ray_trafo : :class:`odl.tomo.RayTransform`
Ray transform
init_z_shape : int, optional
Initial guess for the number of 2D transforms per batch, i.e. the
product of batch and channel dimensions.
"""
if not TOMOSIPO_AVAILABLE:
raise ImportError(MISSING_TOMOSIPO_MESSAGE)
if not ASTRA_AVAILABLE:
raise RuntimeError('Astra is not available.')
if not astra.use_cuda():
raise RuntimeError('Astra is not able to use CUDA.')
super().__init__()
self.ray_trafo = ray_trafo
self._construct_operator(init_z_shape)
def _construct_operator(self, z_shape):
self.torch_ray_trafo = (
get_torch_ray_trafo_parallel_2d(self.ray_trafo, z_shape=z_shape))
self._z_shape = z_shape
def forward(self, x):
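        """
        Apply the forward projection to a batch of images.

        Parameters
        ----------
        x : :class:`torch.Tensor`
            Tensor of shape ``(batch, channels) + ray_trafo.domain.shape``.

        Returns
        -------
        :class:`torch.Tensor`
            Tensor of shape ``(batch, channels) + ray_trafo.range.shape``.
        """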
shape_orig = x.shape
z_shape = shape_orig[0] * shape_orig[1]
if self._z_shape != z_shape:
self._construct_operator(z_shape)
x = x.view(1, z_shape, *x.shape[2:])
x = self.torch_ray_trafo(x)
x = x.view(*shape_orig[:2], *x.shape[2:])
return x
class TorchRayTrafoParallel2DAdjointModule(torch.nn.Module):
"""
    Torch module applying the adjoint of a 2D parallel-beam ray transform
    via tomosipo, calling the direct backward projection routine of astra
    (available in version 1.9.9.dev4), which avoids copying between GPU and
    CPU.

    All 2D transforms of a batch are computed using a single 3D transform.
    To this end, the underlying tomosipo operator is recreated in
    :meth:`forward` every time the product of batch and channel dimensions
    of the current batch differs from that of the previous batch (or, for
    the first batch, from the value of `init_z_shape` passed to
    :meth:`__init__`).
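
    Examples
    --------
    A minimal usage sketch (assuming ``ray_trafo`` is a 2D parallel-beam
    :class:`odl.tomo.RayTransform` and a CUDA device is available; the
    variable names are illustrative):

    >>> torch_ray_trafo_adjoint = TorchRayTrafoParallel2DAdjointModule(
    ...     ray_trafo)
    >>> y = torch.zeros(4, 1, *ray_trafo.range.shape, device='cuda')
    >>> x = torch_ray_trafo_adjoint(y)  # shape: (4, 1) + ray_trafo.domain.shape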
"""
def __init__(self, ray_trafo, init_z_shape=1):
"""
Parameters
----------
ray_trafo : :class:`odl.tomo.RayTransform`
Ray transform
init_z_shape : int, optional
Initial guess for the number of 2D transforms per batch, i.e. the
product of batch and channel dimensions.
"""
if not TOMOSIPO_AVAILABLE:
raise ImportError(MISSING_TOMOSIPO_MESSAGE)
if not ASTRA_AVAILABLE:
raise RuntimeError('Astra is not available.')
if not astra.use_cuda():
raise RuntimeError('Astra is not able to use CUDA.')
super().__init__()
self.ray_trafo = ray_trafo
self._construct_operator(init_z_shape)
def _construct_operator(self, z_shape):
self.torch_ray_trafo_adjoint = (
get_torch_ray_trafo_parallel_2d_adjoint(self.ray_trafo,
z_shape=z_shape))
self._z_shape = z_shape
def forward(self, x):
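        """
        Apply the backprojection (adjoint of the forward projection) to a
        batch of sinograms.

        Parameters
        ----------
        x : :class:`torch.Tensor`
            Tensor of shape ``(batch, channels) + ray_trafo.range.shape``.

        Returns
        -------
        :class:`torch.Tensor`
            Tensor of shape ``(batch, channels) + ray_trafo.domain.shape``.
        """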
shape_orig = x.shape
z_shape = shape_orig[0] * shape_orig[1]
if self._z_shape != z_shape:
self._construct_operator(z_shape)
x = x.view(1, z_shape, *x.shape[2:])
x = self.torch_ray_trafo_adjoint(x)
x = x.view(*shape_orig[:2], *x.shape[2:])
return x
def get_torch_ray_trafo_parallel_2d(ray_trafo, z_shape=1):
"""
    Create a torch autograd-enabled function from a 2D parallel-beam
    :class:`odl.tomo.RayTransform` via tomosipo, calling the direct forward
    projection routine of astra (available in version 1.9.9.dev4), which
    avoids copying between GPU and CPU.
Parameters
----------
ray_trafo : :class:`odl.tomo.RayTransform`
Ray transform
    z_shape : int, optional
        Channel dimension, i.e. the number of 2D transforms that are
        computed jointly (the z-dimension of the single 3D transform).
        Default: ``1``.
Returns
-------
torch_ray_trafo : callable
Torch autograd-enabled function applying the parallel-beam forward
projection.
Input and output have a trivial leading batch dimension and a channel
dimension specified by `z_shape` (default ``1``), i.e. the
input shape is ``(1, z_shape) + ray_trafo.domain.shape`` and the
output shape is ``(1, z_shape) + ray_trafo.range.shape``.
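
    Examples
    --------
    A minimal usage sketch (assuming ``ray_trafo`` is a 2D parallel-beam
    :class:`odl.tomo.RayTransform` and a CUDA device is available):

    >>> torch_ray_trafo = get_torch_ray_trafo_parallel_2d(ray_trafo, z_shape=2)
    >>> x = torch.zeros(1, 2, *ray_trafo.domain.shape, device='cuda')
    >>> y = torch_ray_trafo(x)  # shape: (1, 2) + ray_trafo.range.shape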
"""
if not TOMOSIPO_AVAILABLE:
raise ImportError(MISSING_TOMOSIPO_MESSAGE)
if not ASTRA_AVAILABLE:
raise RuntimeError('Astra is not available.')
if not astra.use_cuda():
raise RuntimeError('Astra is not able to use CUDA.')
vg = from_odl(discretized_space_2d_to_3d(ray_trafo.domain,
z_shape=z_shape))
pg = from_odl(parallel_2d_to_3d_geometry(ray_trafo.geometry,
det_z_shape=z_shape))
ts_op = ts.operator(vg, pg)
torch_ray_trafo = to_autograd(ts_op)
return torch_ray_trafo
def get_torch_ray_trafo_parallel_2d_adjoint(ray_trafo, z_shape=1):
"""
    Create a torch autograd-enabled function from a 2D parallel-beam
    :class:`odl.tomo.RayTransform` via tomosipo, calling the direct backward
    projection routine of astra (available in version 1.9.9.dev4), which
    avoids copying between GPU and CPU.
Parameters
----------
ray_trafo : :class:`odl.tomo.RayTransform`
Ray transform
    z_shape : int, optional
        Channel dimension, i.e. the number of 2D transforms that are
        computed jointly (the z-dimension of the single 3D transform).
        Default: ``1``.
Returns
-------
torch_ray_trafo_adjoint : callable
Torch autograd-enabled function applying the parallel-beam backward
projection.
Input and output have a trivial leading batch dimension and a channel
dimension specified by `z_shape` (default ``1``), i.e. the
input shape is ``(1, z_shape) + ray_trafo.range.shape`` and the
output shape is ``(1, z_shape) + ray_trafo.domain.shape``.
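
    Examples
    --------
    A minimal usage sketch (assuming ``ray_trafo`` is a 2D parallel-beam
    :class:`odl.tomo.RayTransform` and a CUDA device is available):

    >>> torch_ray_trafo_adjoint = get_torch_ray_trafo_parallel_2d_adjoint(
    ...     ray_trafo, z_shape=2)
    >>> y = torch.zeros(1, 2, *ray_trafo.range.shape, device='cuda')
    >>> x = torch_ray_trafo_adjoint(y)  # shape: (1, 2) + ray_trafo.domain.shape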
"""
if not TOMOSIPO_AVAILABLE:
raise ImportError(MISSING_TOMOSIPO_MESSAGE)
if not ASTRA_AVAILABLE:
raise RuntimeError('Astra is not available.')
if not astra.use_cuda():
raise RuntimeError('Astra is not able to use CUDA.')
vg = from_odl(discretized_space_2d_to_3d(ray_trafo.domain,
z_shape=z_shape))
pg = from_odl(parallel_2d_to_3d_geometry(ray_trafo.geometry,
det_z_shape=z_shape))
ts_op = ts.operator(vg, pg)
torch_ray_trafo_adjoint_ts = to_autograd(ts_op.T)
scaling_factor = astra_cuda_bp_scaling_factor(
ray_trafo.range, ray_trafo.domain, ray_trafo.geometry)
def torch_ray_trafo_adjoint(y):
return scaling_factor * torch_ray_trafo_adjoint_ts(y)
return torch_ray_trafo_adjoint
def load_state_dict_convert_data_parallel(model, state_dict):
"""
    Load a state dict into a model, automatically converting the weight
    names if `model` is a :class:`torch.nn.DataParallel` model but the
    stored state dict stems from a non-data-parallel model, or vice versa.
Parameters
----------
    model : :class:`torch.nn.Module`
        Torch model that should load the state dict.
    state_dict : dict
        Torch state dict to load.
Raises
------
RuntimeError
If there are missing or unexpected keys in the state dict.
This error is not raised when conversion of the weight names succeeds.
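
    Examples
    --------
    A minimal usage sketch (``MyModel`` and the checkpoint path are
    hypothetical placeholders):

    >>> model = torch.nn.DataParallel(MyModel())
    >>> state_dict = torch.load('weights.pt', map_location='cpu')
    >>> load_state_dict_convert_data_parallel(model, state_dict)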
"""
missing_keys, unexpected_keys = model.load_state_dict(
state_dict, strict=False)
if missing_keys or unexpected_keys:
        # since loading directly failed, assume that the keys of state_dict
        # are prefixed the other way (w.r.t. 'module.') than model expects
if isinstance(model, torch.nn.DataParallel):
state_dict = {('module.' + k): v
for k, v in state_dict.items()}
missing_keys2, unexpected_keys2 = (
model.load_state_dict(state_dict, strict=False))
if missing_keys2 or unexpected_keys2:
if len(missing_keys2) < len(missing_keys):
raise RuntimeError(
'Failed to load learned weights. Missing keys (in '
                        'case of prefixing with \'module.\', which leads to '
'fewer missing keys):\n{}'
.format(', '.join(
('"{}"'.format(k) for k in missing_keys2))))
else:
raise RuntimeError(
'Failed to load learned weights (also when trying '
'with additional \'module.\' prefix). Missing '
'keys:\n{}'
.format(', '.join(
('"{}"'.format(k) for k in missing_keys))))
else:
if all(k.startswith('module.') for k in state_dict.keys()):
state_dict = {k[len('module.'):]: v
for k, v in state_dict.items()}
missing_keys2, unexpected_keys2 = (
model.load_state_dict(state_dict,
strict=False))
if missing_keys2 or unexpected_keys2:
if len(missing_keys2) < len(missing_keys):
raise RuntimeError(
'Failed to load learned weights. Missing keys (in '
                            'case of removing \'module.\' prefix, which leads '
'to fewer missing keys):\n{}'
.format(', '.join(
('"{}"'.format(k) for k in missing_keys2))))
else:
raise RuntimeError(
'Failed to load learned weights (also when '
'removing \'module.\' prefix). Missing keys:\n{}'
.format(', '.join(
('"{}"'.format(k) for k in missing_keys))))
else:
raise RuntimeError(
'Failed to load learned weights. Missing keys:\n{}'
.format(', '.join(
('"{}"'.format(k) for k in missing_keys))))