dival.util.torch_utility module

Provides utilities related to PyTorch.

The classes and functions TorchRayTrafoParallel2DModule, TorchRayTrafoParallel2DAdjointModule, get_torch_ray_trafo_parallel_2d() and get_torch_ray_trafo_parallel_2d_adjoint() in this module rely on the tomosipo library and on experimental astra features available in astra version 1.9.9.dev4 with CUDA. All of these requirements must be fulfilled to instantiate or call these classes and functions; otherwise an ImportError is raised.
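The text does not say how the ImportError is triggered. One common pattern, sketched here as an assumption rather than as dival's actual mechanism, is to check for the optional modules once at import time and raise only when a dependent function is called, so the rest of the module stays importable. The decorator requires_modules is hypothetical, and it only checks that top-level modules can be found; it does not check the astra version or CUDA availability.

```python
import functools
import importlib.util


def requires_modules(*module_names):
    """Hypothetical decorator: raise ImportError at call time if a module is missing.

    Only top-level module names are supported; the astra version and CUDA
    availability are not checked.
    """
    # Checked once, when the decorator is applied (typically at import time).
    missing = [name for name in module_names
               if importlib.util.find_spec(name) is None]

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if missing:
                raise ImportError(
                    f"{func.__name__} requires missing module(s): "
                    f"{', '.join(missing)}")
            return func(*args, **kwargs)
        return wrapper

    return decorator
```

Applied as `@requires_modules("tomosipo", "astra")`, a decorated function fails with an informative ImportError only when it is called, while the module itself can still be imported.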

class dival.util.torch_utility.RandomAccessTorchDataset(dataset, part, reshape=None, transform=None)

Bases: torch.utils.data.dataset.Dataset

__init__(dataset, part, reshape=None, transform=None)
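Only the signature is documented here. As a rough sketch of what a map-style wrapper with this signature might do, the following assumes a source dataset exposing get_len(part) and get_sample(idx, part) (modeled by the hypothetical ToyPairDataset), that reshape holds one target shape (or None) per sample element, and that transform is applied to the whole sample tuple. RandomAccessSketch is a hypothetical name; unlike the dival class it does not subclass torch.utils.data.Dataset or return tensors, so it runs without PyTorch.

```python
import numpy as np


class ToyPairDataset:
    """Hypothetical in-memory source with get_len(part) and get_sample(idx, part)."""

    def __init__(self, num_samples):
        self.data = {"train": [(np.full((2, 3), i, dtype=np.float32),
                                np.full((6,), i, dtype=np.float32))
                               for i in range(num_samples)]}

    def get_len(self, part):
        return len(self.data[part])

    def get_sample(self, idx, part):
        return self.data[part][idx]


class RandomAccessSketch:
    """Map-style wrapper: __len__ and __getitem__ over one dataset part."""

    def __init__(self, dataset, part, reshape=None, transform=None):
        self.dataset = dataset
        self.part = part
        self.reshape = reshape        # assumed: one shape (or None) per element
        self.transform = transform    # assumed: callable on the sample tuple

    def __len__(self):
        return self.dataset.get_len(self.part)

    def __getitem__(self, idx):
        sample = self.dataset.get_sample(idx, self.part)
        if self.reshape is not None:
            # Reshape each element; None keeps the element's shape unchanged.
            sample = tuple(arr if shape is None else arr.reshape(shape)
                           for arr, shape in zip(sample, self.reshape))
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
```

Random access by index is what lets a sampler shuffle the samples of a map-style dataset.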

class dival.util.torch_utility.GeneratorTorchDataset(dataset, part, reshape=None, transform=None)

Bases: torch.utils.data.dataset.IterableDataset

__init__(dataset, part, reshape=None, transform=None)

generate()
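For the iterable-style counterpart, a minimal sketch under the same assumptions as for a map-style wrapper: the hypothetical source ToyGeneratorSource exposes generator(part), reshape holds one shape (or None) per element, and transform acts on the whole sample. GeneratorSketch is hypothetical; it only mirrors the idea that samples come lazily from generate() and that iterating the dataset consumes that generator. It does not subclass torch.utils.data.IterableDataset, so it runs without PyTorch.

```python
import numpy as np


class ToyGeneratorSource:
    """Hypothetical source exposing generator(part), yielding (obs, gt) pairs."""

    def generator(self, part):
        for i in range(3):
            yield np.full((4,), float(i)), np.full((2, 2), float(i))


class GeneratorSketch:
    """Iterable-style wrapper: samples are produced lazily, in source order."""

    def __init__(self, dataset, part, reshape=None, transform=None):
        self.dataset = dataset
        self.part = part
        self.reshape = reshape        # assumed: one shape (or None) per element
        self.transform = transform    # assumed: callable on the sample tuple

    def generate(self):
        for sample in self.dataset.generator(self.part):
            if self.reshape is not None:
                # Reshape each element; None keeps the element's shape unchanged.
                sample = tuple(arr if shape is None else arr.reshape(shape)
                               for arr, shape in zip(sample, self.reshape))
            if self.transform is not None:
                sample = self.transform(sample)
            yield sample

    def __iter__(self):
        # Iterable-style datasets are consumed through iter().
        return self.generate()
```

In PyTorch in general, each DataLoader worker holds its own copy of an iterable-style dataset, so with num_workers > 0 samples are duplicated unless the work is split, for example via torch.utils.data.get_worker_info().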

class dival.util.torch_utility.TorchRayTrafoParallel2DModule(ray_trafo, init_z_shape=1)

Bases: torch.nn.modules.module.Module

Torch module applying a 2D parallel-beam ray transform using tomosipo, calling the direct forward projection routine of astra, which avoids copying data between GPU and CPU (available in astra 1.9.9.dev4).

All 2D transforms are computed using a single 3D transform. To this end, the tomosipo operator is renewed in forward() every time the product of the batch and channel dimensions of the current batch differs from that of the previous batch or, for the first batch, from the value of init_z_shape passed to __init__().
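The renewal rule can be sketched without tomosipo. In this hypothetical sketch, make_stack_operator stands in for building a 3D operator whose geometry is fixed to z slices, and RenewingOperatorModule stacks the N * C input images along one axis, rebuilding the operator only when N * C changes. Neither name is part of dival, and the toy "projection" (a sum over the last axis) is only a placeholder.

```python
import numpy as np


def make_stack_operator(z):
    """Hypothetical stand-in for a 3D operator built for exactly z slices."""
    def apply(volume):
        # Mimic an operator whose geometry was set up for a fixed z extent.
        if volume.shape[0] != z:
            raise ValueError(f"operator built for {z} slices, got {volume.shape[0]}")
        # Placeholder projection: sum each slice over its last axis.
        return volume.sum(axis=-1)
    return apply


class RenewingOperatorModule:
    """Apply a fixed-z stack operator to (N, C, H, W) batches, rebuilding on demand."""

    def __init__(self, operator_factory, init_z_shape=1):
        self.operator_factory = operator_factory
        self.z_shape = init_z_shape
        self.operator = operator_factory(init_z_shape)
        self.num_renewals = 0

    def __call__(self, x):
        n, c = x.shape[:2]
        z = n * c
        if z != self.z_shape:
            # The product N * C changed: build a new operator for z slices.
            self.operator = self.operator_factory(z)
            self.z_shape = z
            self.num_renewals += 1
        # Stack all 2D inputs along one axis and apply a single operator.
        out = self.operator(x.reshape((z,) + x.shape[2:]))
        return out.reshape((n, c) + out.shape[1:])
```

Choosing init_z_shape equal to the usual N * C avoids a renewal on the first batch; batches of shape (2, 3, ...) and (3, 2, ...) share the same operator because only the product matters.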

__init__(ray_trafo, init_z_shape=1)

Parameters
  • ray_trafo (odl.tomo.RayTransform) – Ray transform

  • init_z_shape (int, optional) – Initial guess for the number of 2D transforms per batch, i.e. the product of batch and channel dimensions.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool

class dival.util.torch_utility.TorchRayTrafoParallel2DAdjointModule(ray_trafo, init_z_shape=1)

Bases: torch.nn.modules.module.Module

Torch module applying the adjoint of a 2D parallel-beam ray transform using tomosipo, calling the direct backward projection routine of astra, which avoids copying data between GPU and CPU (available in astra 1.9.9.dev4).
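The defining property of an adjoint A* is that ⟨Ax, y⟩ = ⟨x, A*y⟩ for all x and y, and checking it on random inputs (the "dot-product test") is a standard way to validate a forward/backward projector pair. The sketch below uses a hypothetical toy parallel-beam projector with only two axis-aligned angles (sums over each image axis) in place of tomosipo/astra; its backprojection smears each projection back across the image.

```python
import numpy as np


def toy_forward(x):
    """Toy two-angle parallel-beam projector: sums over each image axis."""
    # x has shape (..., n, n); the output has shape (..., 2, n).
    return np.stack([x.sum(axis=-2), x.sum(axis=-1)], axis=-2)


def toy_adjoint(y):
    """Exact adjoint of toy_forward: smear each projection back across the image."""
    # y has shape (..., 2, n); the output has shape (..., n, n).
    return y[..., 0, :][..., np.newaxis, :] + y[..., 1, :][..., :, np.newaxis]


rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8, 8))  # three 2D images stacked along z
y = rng.standard_normal((3, 2, 8))  # three toy sinograms
lhs = np.vdot(toy_forward(x), y)    # <A x, y>
rhs = np.vdot(x, toy_adjoint(y))    # <x, A* y>
print(abs(lhs - rhs))
```

The printed difference should be at the level of floating-point rounding; a clearly nonzero value would indicate that the backward routine is not the adjoint of the forward routine.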

All 2D transforms are computed using a single 3D transform. To this end, the tomosipo operator is renewed in forward() every time the product of the batch and channel dimensions of the current batch differs from that of the previous batch or, for the first batch, from the value of init_z_shape passed to __init__().

__init__(ray_trafo, init_z_shape=1)

Parameters
  • ray_trafo (odl.tomo.RayTransform) – Ray transform

  • init_z_shape (int, optional) – Initial guess for the number of 2D transforms per batch, i.e. the product of batch and channel dimensions.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool

dival.util.torch_utility.get_torch_ray_trafo_parallel_2d(ray_trafo, z_shape=1)

Create a torch autograd-enabled function from a 2D parallel-beam odl.tomo.RayTransform using tomosipo, calling the direct forward projection routine of astra, which avoids copying data between GPU and CPU (available in astra 1.9.9.dev4).

Parameters
  • ray_trafo (odl.tomo.RayTransform) – Ray transform

  • z_shape (int, optional) – Size of the channel dimension, i.e. the number of 2D transforms computed per call. Default: 1.

Returns

torch_ray_trafo – Torch autograd-enabled function applying the parallel-beam forward projection. Input and output have a trivial leading batch dimension and a channel dimension specified by z_shape (default 1), i.e. the input shape is (1, z_shape) + ray_trafo.domain.shape and the output shape is (1, z_shape) + ray_trafo.range.shape.

Return type

callable
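To illustrate the documented shape convention, the following sketch builds a hypothetical callable with the same input and output layout, (1, z_shape) + domain.shape to (1, z_shape) + range.shape, for an n x n domain. The ray transform is replaced by a toy two-angle projector (sums over each image axis), so the range shape is (2, n); make_toy_ray_trafo_2d is not part of dival, and the sketch is neither autograd-enabled nor GPU-based.

```python
import numpy as np


def make_toy_ray_trafo_2d(n, z_shape=1):
    """Hypothetical analogue of the returned callable, for an n x n domain.

    Input shape:  (1, z_shape, n, n), i.e. (1, z_shape) + domain.shape
    Output shape: (1, z_shape, 2, n), i.e. (1, z_shape) + range.shape
    """
    domain_shape, range_shape = (n, n), (2, n)

    def toy_ray_trafo(x):
        expected = (1, z_shape) + domain_shape
        if x.shape != expected:
            raise ValueError(f"expected input shape {expected}, got {x.shape}")
        # Treat the z_shape images as one stacked 3D volume of z_shape slices.
        volume = x.reshape((z_shape,) + domain_shape)
        # Toy two-angle parallel-beam projection of every slice at once.
        sino = np.stack([volume.sum(axis=-2), volume.sum(axis=-1)], axis=-2)
        return sino.reshape((1, z_shape) + range_shape)

    return toy_ray_trafo
```

For example, make_toy_ray_trafo_2d(5, z_shape=3) maps an array of shape (1, 3, 5, 5) to shape (1, 3, 2, 5) and rejects inputs whose leading dimensions are not (1, 3).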

dival.util.torch_utility.get_torch_ray_trafo_parallel_2d_adjoint(ray_trafo, z_shape=1)

Create a torch autograd-enabled function from a 2D parallel-beam odl.tomo.RayTransform using tomosipo, calling the direct backward projection routine of astra, which avoids copying data between GPU and CPU (available in astra 1.9.9.dev4).

Parameters
  • ray_trafo (odl.tomo.RayTransform) – Ray transform

  • z_shape (int, optional) – Size of the channel dimension, i.e. the number of 2D transforms computed per call. Default: 1.

Returns

torch_ray_trafo_adjoint – Torch autograd-enabled function applying the parallel-beam backward projection. Input and output have a trivial leading batch dimension and a channel dimension specified by z_shape (default 1), i.e. the input shape is (1, z_shape) + ray_trafo.range.shape and the output shape is (1, z_shape) + ray_trafo.domain.shape.

Return type

callable
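The forward projection and its adjoint are linked through autograd: for a linear operator A, the backward pass of x ↦ Ax maps an output gradient g to A*g. For the least-squares loss L(x) = ½‖Ax − y‖², this gives ∇L(x) = A*(Ax − y). The sketch below checks this with a hypothetical toy two-angle projector and a central finite difference; it uses NumPy instead of torch, and the source does not state how the dival functions wire up their backward passes.

```python
import numpy as np


def toy_forward(x):
    """Toy two-angle parallel-beam projector: sums over each image axis."""
    return np.stack([x.sum(axis=-2), x.sum(axis=-1)], axis=-2)


def toy_adjoint(y):
    """Exact adjoint of toy_forward."""
    return y[..., 0, :][..., np.newaxis, :] + y[..., 1, :][..., :, np.newaxis]


def loss(x, y):
    r = toy_forward(x) - y
    return 0.5 * np.sum(r * r)


def loss_grad(x, y):
    # Chain rule: backpropagating through the projection applies the adjoint.
    return toy_adjoint(toy_forward(x) - y)


rng = np.random.default_rng(1)
x = rng.standard_normal((1, 1, 6, 6))  # (1, z_shape) + domain shape
y = rng.standard_normal((1, 1, 2, 6))  # (1, z_shape) + range shape
g = loss_grad(x, y)

# Central finite difference of the loss along a random direction d.
d = rng.standard_normal(x.shape)
eps = 1e-6
fd = (loss(x + eps * d, y) - loss(x - eps * d, y)) / (2 * eps)
print(fd, np.vdot(g, d))
```

Because the loss is quadratic, the central difference matches the directional derivative up to rounding, so the two printed numbers should agree closely.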

dival.util.torch_utility.load_state_dict_convert_data_parallel(model, state_dict)

Load a state dict into a model, automatically converting the weight names if model is an nn.DataParallel model but the stored state dict stems from a non-data-parallel model, or vice versa.

Parameters
  • model (nn.Module) – Torch model that should load the state dict.

  • state_dict (dict) – Torch state dict

Raises

RuntimeError – If there are missing or unexpected keys in the state dict. The error is not raised if converting the weight names resolves the mismatch.
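The conversion rests on the fact that nn.DataParallel stores the wrapped model as its module attribute, so its state dict keys carry a "module." prefix. A minimal sketch of the renaming, assuming that the prefix is the only difference between the two key sets (the helper convert_state_dict_keys is hypothetical and not dival's implementation), uses plain dicts so it runs without PyTorch:

```python
def convert_state_dict_keys(state_dict, model_is_data_parallel):
    """Hypothetical helper: add or strip the 'module.' prefix of nn.DataParallel."""
    prefix = "module."
    stored_is_data_parallel = all(k.startswith(prefix) for k in state_dict)
    if model_is_data_parallel and not stored_is_data_parallel:
        # Plain model weights loaded into a DataParallel wrapper: add the prefix.
        return {prefix + k: v for k, v in state_dict.items()}
    if stored_is_data_parallel and not model_is_data_parallel:
        # DataParallel weights loaded into a plain model: strip the prefix.
        return {k[len(prefix):]: v for k, v in state_dict.items()}
    # Both sides already agree; return an unchanged copy.
    return dict(state_dict)
```

With torch, this would be used as `model.load_state_dict(convert_state_dict_keys(sd, isinstance(model, nn.DataParallel)))`, letting load_state_dict report any keys that still do not match.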