Source code for dival.reconstructors.dip_ct_reconstructor

from warnings import warn
from functools import partial
from tqdm import tqdm
import torch
import numpy as np

from torch.optim import Adam
from torch.nn import MSELoss

from odl.contrib.torch import OperatorModule

from dival.reconstructors import IterativeReconstructor
from dival.reconstructors.networks.unet import UNet
from dival.util.torch_losses import poisson_loss, tv_loss
from dival.util.constants import MU_MAX

MIN = -1000
MAX = 1000


[docs]class DeepImagePriorCTReconstructor(IterativeReconstructor): """ CT reconstructor applying DIP with TV regularization (see [2]_). The DIP was introduced in [1]_. References ---------- .. [1] V. Lempitsky, A. Vedaldi, and D. Ulyanov, 2018, "Deep Image Prior". IEEE/CVF Conference on Computer Vision and Pattern Recognition. https://doi.org/10.1109/CVPR.2018.00984 .. [2] D. Otero Baguer, J. Leuschner, M. Schmidt, 2020, "Computed Tomography Reconstruction Using Deep Image Prior and Learned Reconstruction Methods". Inverse Problems. https://doi.org/10.1088/1361-6420/aba415 """ HYPER_PARAMS = { 'lr': {'default': 1e-3, 'range': [1e-5, 1e-1]}, 'gamma': {'default': 1e-4, 'range': [1e-7, 1e-0], 'grid_search_options': {'num_samples': 20}}, 'scales': {'default': 4, 'choices': [3, 4, 5, 6, 7]}, 'channels': {'default': [128] * 5}, 'skip_channels': {'default': [4] * 5}, 'iterations': {'default': 5000, 'range': [1, 50000]}, 'loss_function': {'default': 'mse', 'choices': ['mse', 'poisson']}, 'photons_per_pixel': # used by 'poisson' loss function {'default': 4096, 'range': [1000, 10000]}, 'mu_max': # used by 'poisson' loss function {'default': MU_MAX, 'range': [1., 10000.]} }
[docs] def __init__(self, ray_trafo, callback_func=None, callback_func_interval=100, show_pbar=True, torch_manual_seed=10, **kwargs): """ Parameters ---------- ray_trafo : `odl.tomo.operators.RayTransform` The forward operator callback_func : callable, optional Callable with signature ``callback_func(iteration, reconstruction, loss)`` that is called after every `callback_func_interval` iterations, starting after the first iteration. It is additionally called after the last iteration. Note that it differs from the inherited `IterativeReconstructor.callback` (which is also supported) in that the latter is of type :class:`odl.solvers.util.callback.Callback`, which only receives the reconstruction, such that the loss would have to be recomputed. callback_func_interval : int, optional Number of iterations between calls to `callback_func`. Default: `100`. show_pbar : bool, optional Whether to show a tqdm progress bar during reconstruction. torch_manual_seed : int, optional Fixed seed to set by ``torch.manual_seed`` before reconstruction. The default is `10`. It can be set to `None` or `False` to disable the manual seed. """ super().__init__( reco_space=ray_trafo.domain, observation_space=ray_trafo.range, **kwargs) self.callback_func = callback_func self.ray_trafo = ray_trafo self.ray_trafo_module = OperatorModule(self.ray_trafo) self.callback_func = callback_func self.callback_func_interval = callback_func_interval self.show_pbar = show_pbar self.torch_manual_seed = torch_manual_seed
[docs] def get_activation(self, layer_index): return self.model.layer_output(self.net_input, layer_index)
def _reconstruct(self, observation, *args, **kwargs): if self.torch_manual_seed: torch.random.manual_seed(self.torch_manual_seed) output_depth = 1 input_depth = 1 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.net_input = 0.1 * \ torch.randn(input_depth, *self.reco_space.shape)[None].to(device) self.model = UNet( input_depth, output_depth, channels=self.channels[:self.scales], skip_channels=self.skip_channels[:self.scales], use_sigmoid=True, use_norm=True).to(device) self.optimizer = Adam(self.model.parameters(), lr=self.lr) y_delta = torch.tensor(np.asarray(observation), dtype=torch.float32) y_delta = y_delta.view(1, 1, *y_delta.shape) y_delta = y_delta.to(device) if self.loss_function == 'mse': criterion = MSELoss() elif self.loss_function == 'poisson': criterion = partial(poisson_loss, photons_per_pixel=self.photons_per_pixel, mu_max=self.mu_max) else: warn('Unknown loss function, falling back to MSE') criterion = MSELoss() best_loss = np.inf best_output = self.model(self.net_input).detach() for i in tqdm(range(self.iterations), desc='DIP', disable=not self.show_pbar): self.optimizer.zero_grad() output = self.model(self.net_input) loss = criterion(self.ray_trafo_module(output), y_delta) + self.gamma * tv_loss(output) loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1) self.optimizer.step() for p in self.model.parameters(): p.data.clamp_(MIN, MAX) if loss.item() < best_loss: best_loss = loss.item() best_output = output.detach() if (self.callback_func is not None and (i % self.callback_func_interval == 0 or i == self.iterations-1)): self.callback_func( iteration=i, reconstruction=best_output[0, 0, ...].cpu().numpy(), loss=best_loss) if self.callback is not None: self.callback(self.reco_space.element( best_output[0, 0, ...].cpu().numpy())) return self.reco_space.element(best_output[0, 0, ...].cpu().numpy())