dival.util.torch_utility module
Provides utilities related to PyTorch.
The classes and functions in this module rely on the tomosipo library and on experimental astra features that are available in version 1.9.9.dev4 using CUDA. All of these requirements must be fulfilled in order to instantiate or call these classes and functions; otherwise an ImportError is raised.
- class dival.util.torch_utility.RandomAccessTorchDataset(dataset, part, reshape=None, transform=None)[source]
Bases: Dataset
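A minimal usage sketch, assuming the 'lodopab' standard dataset (which must be available locally) and illustrative reshape values that prepend a channel dimension; the wrapper exposes one dataset part as a map-style torch dataset for use with a DataLoader:

```python
from torch.utils.data import DataLoader

from dival import get_standard_dataset
from dival.util.torch_utility import RandomAccessTorchDataset

dataset = get_standard_dataset('lodopab')

# wrap the 'train' part; reshape prepends a channel dimension to
# observations and ground truths (illustrative choice)
torch_dataset = RandomAccessTorchDataset(
    dataset, 'train',
    reshape=((1,) + dataset.space[0].shape,
             (1,) + dataset.space[1].shape))

loader = DataLoader(torch_dataset, batch_size=4, shuffle=True)
obs, gt = next(iter(loader))  # observation and ground truth batches
```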
- class dival.util.torch_utility.GeneratorTorchDataset(dataset, part, reshape=None, transform=None)[source]
Bases: IterableDataset
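An analogous sketch for the generator-based wrapper, which exposes one dataset part as an iterable-style torch dataset (again assuming the 'lodopab' standard dataset and illustrative reshape values):

```python
from torch.utils.data import DataLoader

from dival import get_standard_dataset
from dival.util.torch_utility import GeneratorTorchDataset

dataset = get_standard_dataset('lodopab')

torch_dataset = GeneratorTorchDataset(
    dataset, 'validation',
    reshape=((1,) + dataset.space[0].shape,
             (1,) + dataset.space[1].shape))

# iterable-style datasets are consumed in generator order (no shuffling)
loader = DataLoader(torch_dataset, batch_size=4)
obs, gt = next(iter(loader))
```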
- class dival.util.torch_utility.TorchRayTrafoParallel2DModule(ray_trafo, init_z_shape=1)[source]
Bases: Module
Torch module applying a 2D parallel-beam ray transform using tomosipo that calls the direct forward projection routine of astra, which avoids copying between GPU and CPU (available in 1.9.9.dev4).
All 2D transforms are computed using a single 3D transform. To this end, the tomosipo operator used internally is renewed in forward() whenever the product of the batch and channel dimensions of the current batch differs from that of the previous batch, or, for the first batch, from the value of init_z_shape passed to __init__().
- __init__(ray_trafo, init_z_shape=1)[source]
- Parameters:
ray_trafo (odl.tomo.RayTransform) – Ray transform.
init_z_shape (int, optional) – Initial guess for the number of 2D transforms per batch, i.e. the product of batch and channel dimensions.
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.
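A usage sketch, assuming the 'lodopab' standard dataset (whose ray transform is 2D parallel-beam and is assumed to be exposed as dataset.ray_trafo) and a CUDA device; the module maps an image batch of shape (N, C) + ray_trafo.domain.shape to a projection batch of shape (N, C) + ray_trafo.range.shape and can be embedded in a larger nn.Module:

```python
import torch

from dival import get_standard_dataset
from dival.util.torch_utility import TorchRayTrafoParallel2DModule

dataset = get_standard_dataset('lodopab')
ray_trafo_module = TorchRayTrafoParallel2DModule(
    dataset.ray_trafo, init_z_shape=4)  # init_z_shape = batch_size * channels

x = torch.zeros((4, 1) + dataset.ray_trafo.domain.shape, device='cuda')
y = ray_trafo_module(x)  # shape: (4, 1) + dataset.ray_trafo.range.shape
```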
- class dival.util.torch_utility.TorchRayTrafoParallel2DAdjointModule(ray_trafo, init_z_shape=1)[source]
Bases: Module
Torch module applying the adjoint of a 2D parallel-beam ray transform using tomosipo that calls the direct backward projection routine of astra, which avoids copying between GPU and CPU (available in 1.9.9.dev4).
All 2D transforms are computed using a single 3D transform. To this end, the tomosipo operator used internally is renewed in forward() whenever the product of the batch and channel dimensions of the current batch differs from that of the previous batch, or, for the first batch, from the value of init_z_shape passed to __init__().
- __init__(ray_trafo, init_z_shape=1)[source]
- Parameters:
ray_trafo (odl.tomo.RayTransform) – Ray transform.
init_z_shape (int, optional) – Initial guess for the number of 2D transforms per batch, i.e. the product of batch and channel dimensions.
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.
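An analogous sketch for the adjoint module, under the same assumptions; it maps a projection batch back to the image domain:

```python
import torch

from dival import get_standard_dataset
from dival.util.torch_utility import TorchRayTrafoParallel2DAdjointModule

dataset = get_standard_dataset('lodopab')
adjoint_module = TorchRayTrafoParallel2DAdjointModule(
    dataset.ray_trafo, init_z_shape=4)

y = torch.zeros((4, 1) + dataset.ray_trafo.range.shape, device='cuda')
x = adjoint_module(y)  # shape: (4, 1) + dataset.ray_trafo.domain.shape
```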
- dival.util.torch_utility.get_torch_ray_trafo_parallel_2d(ray_trafo, z_shape=1)[source]
Create a torch autograd-enabled function from a 2D parallel-beam odl.tomo.RayTransform using tomosipo that calls the direct forward projection routine of astra, which avoids copying between GPU and CPU (available in 1.9.9.dev4).
- Parameters:
ray_trafo (odl.tomo.RayTransform) – Ray transform.
z_shape (int, optional) – Channel dimension. Default: 1.
- Returns:
torch_ray_trafo – Torch autograd-enabled function applying the parallel-beam forward projection. Input and output have a trivial leading batch dimension and a channel dimension specified by z_shape (default 1), i.e. the input shape is (1, z_shape) + ray_trafo.domain.shape and the output shape is (1, z_shape) + ray_trafo.range.shape.
- Return type:
callable
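A usage sketch, under the same assumptions as above ('lodopab' standard dataset with a parallel-beam dataset.ray_trafo, CUDA device); the returned function supports backpropagation through the projection:

```python
import torch

from dival import get_standard_dataset
from dival.util.torch_utility import get_torch_ray_trafo_parallel_2d

dataset = get_standard_dataset('lodopab')
torch_ray_trafo = get_torch_ray_trafo_parallel_2d(dataset.ray_trafo, z_shape=1)

x = torch.zeros((1, 1) + dataset.ray_trafo.domain.shape,
                device='cuda', requires_grad=True)
y = torch_ray_trafo(x)  # shape: (1, 1) + dataset.ray_trafo.range.shape
y.sum().backward()      # gradients flow back to x through the projection
```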
- dival.util.torch_utility.get_torch_ray_trafo_parallel_2d_adjoint(ray_trafo, z_shape=1)[source]
Create a torch autograd-enabled function from a 2D parallel-beam odl.tomo.RayTransform using tomosipo that calls the direct backward projection routine of astra, which avoids copying between GPU and CPU (available in 1.9.9.dev4).
- Parameters:
ray_trafo (odl.tomo.RayTransform) – Ray transform.
z_shape (int, optional) – Channel dimension. Default: 1.
- Returns:
torch_ray_trafo_adjoint – Torch autograd-enabled function applying the parallel-beam backward projection. Input and output have a trivial leading batch dimension and a channel dimension specified by z_shape (default 1), i.e. the input shape is (1, z_shape) + ray_trafo.range.shape and the output shape is (1, z_shape) + ray_trafo.domain.shape.
- Return type:
callable
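An analogous sketch for the adjoint, under the same assumptions:

```python
import torch

from dival import get_standard_dataset
from dival.util.torch_utility import get_torch_ray_trafo_parallel_2d_adjoint

dataset = get_standard_dataset('lodopab')
torch_ray_trafo_adjoint = get_torch_ray_trafo_parallel_2d_adjoint(
    dataset.ray_trafo, z_shape=1)

y = torch.zeros((1, 1) + dataset.ray_trafo.range.shape,
                device='cuda', requires_grad=True)
x = torch_ray_trafo_adjoint(y)  # shape: (1, 1) + dataset.ray_trafo.domain.shape
x.sum().backward()              # gradients flow back to y through the backprojection
```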
- dival.util.torch_utility.load_state_dict_convert_data_parallel(model, state_dict)[source]
Load a state dict into a model, while automatically converting the weight names if model is an nn.DataParallel model but the stored state dict stems from a non-data-parallel model, or vice versa.
- Parameters:
model (nn.Module) – Torch model that should load the state dict.
state_dict (dict) – Torch state dict.
- Raises:
RuntimeError – If there are missing or unexpected keys in the state dict. This error is not raised when conversion of the weight names succeeds.
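A minimal sketch with a hypothetical model and checkpoint path ('params.pt'), purely for illustration; the typical use case is loading weights saved from a plain model into an nn.DataParallel-wrapped one (or vice versa):

```python
import torch
import torch.nn as nn

from dival.util.torch_utility import load_state_dict_convert_data_parallel

# hypothetical DataParallel-wrapped model
model = nn.DataParallel(nn.Sequential(nn.Conv2d(1, 16, 3, padding=1)))

# hypothetical checkpoint, e.g. saved from a model without DataParallel
state_dict = torch.load('params.pt', map_location='cpu')

# converts the weight names between the DataParallel and plain naming
# conventions if necessary, then loads them into the model
load_state_dict_convert_data_parallel(model, state_dict)
```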