# Source code for dival.reconstructors.tvadam_ct_reconstructor

from warnings import warn
from functools import partial
from tqdm import tqdm
import torch
import numpy as np

from torch.optim import Adam
from torch.nn import MSELoss

from odl.contrib.torch import OperatorModule
from odl.tomo import fbp_op

from dival.reconstructors import IterativeReconstructor
from dival.util.torch_losses import poisson_loss, tv_loss
from dival.util.constants import MU_MAX


class TVAdamCTReconstructor(IterativeReconstructor):
    """
    CT reconstructor minimizing a TV-functional with the Adam optimizer.

    The image is initialized with a filtered back-projection (FBP) of the
    observation. The pixel values are then optimized directly with
    :class:`torch.optim.Adam` to minimize::

        loss_function(ray_trafo(x), observation) + gamma * tv_loss(x)

    The iterate with the lowest objective value encountered during the
    optimization is returned.
    """

    HYPER_PARAMS = {
        'lr':
            {'default': 1e-3,
             'range': [1e-5, 1e-1]},
        'gamma':
            {'default': 1e-4,
             'range': [1e-7, 1e-0],
             'grid_search_options': {'num_samples': 20}},
        'iterations':
            {'default': 5000,
             'range': [1, 50000]},
        'loss_function':
            {'default': 'mse',
             'choices': ['mse', 'poisson']},
        'photons_per_pixel':  # used by 'poisson' loss function
            {'default': 4096},
        'mu_max':  # used by 'poisson' loss function
            {'default': MU_MAX},
        'init_filter_type':  # filter type of the initial FBP
            {'default': 'Hann'},
        'init_frequency_scaling':  # frequency scaling of the initial FBP
            {'default': 0.1}
    }

    def __init__(self, ray_trafo, callback_func=None,
                 callback_func_interval=100, show_pbar=True, **kwargs):
        """
        Parameters
        ----------
        ray_trafo : `odl.tomo.operators.RayTransform`
            The forward operator
        callback_func : callable, optional
            Callable with signature
            ``callback_func(iteration, reconstruction, loss)`` that is called
            after every `callback_func_interval` iterations, starting
            after the first iteration. It is additionally called after the
            last iteration.
            `reconstruction` and `loss` refer to the best iterate found so
            far (lowest objective value), not necessarily the current one.
            Note that it differs from the inherited
            `IterativeReconstructor.callback` (which is also supported) in
            that the latter is of type :class:`odl.solvers.util.callback.Callback`,
            which only receives the reconstruction, such that the loss
            would have to be recomputed.
        callback_func_interval : int, optional
            Number of iterations between calls to `callback_func`.
            Default: `100`.
        show_pbar : bool, optional
            Whether to show a tqdm progress bar during reconstruction.
        **kwargs
            Passed on to the ``IterativeReconstructor`` constructor.
        """
        super().__init__(
            reco_space=ray_trafo.domain, observation_space=ray_trafo.range,
            **kwargs)
        self.ray_trafo = ray_trafo
        # Torch wrapper of the ODL operator, so that the data term can be
        # part of the autograd graph that ``loss.backward()`` traverses.
        self.ray_trafo_module = OperatorModule(self.ray_trafo)
        self.callback_func = callback_func
        self.callback_func_interval = callback_func_interval
        self.show_pbar = show_pbar

    def _reconstruct(self, observation, *args, **kwargs):
        """
        Reconstruct an image from `observation` by TV-regularized Adam.

        Parameters
        ----------
        observation : element of `observation_space` or array-like
            The measured data; converted with ``np.asarray`` to float32.
        *args, **kwargs
            Ignored.

        Returns
        -------
        reconstruction : element of `reco_space`
            The iterate with the lowest objective value. The initial FBP is
            returned if no iteration is performed.
        """
        self.fbp_op = fbp_op(
            self.ray_trafo, filter_type=self.init_filter_type,
            frequency_scaling=self.init_frequency_scaling)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Start from the FBP. ``[None]`` adds a leading batch axis.
        # float32 matches the dtype used for the observation below.
        x_init = np.asarray(self.fbp_op(observation))
        self.output = torch.tensor(x_init, dtype=torch.float32)[None].to(device)
        self.output.requires_grad = True
        self.optimizer = Adam([self.output], lr=self.lr)

        y_delta = torch.tensor(np.asarray(observation), dtype=torch.float32)
        y_delta = y_delta.view(1, *y_delta.shape).to(device)

        if self.loss_function == 'mse':
            criterion = MSELoss()
        elif self.loss_function == 'poisson':
            criterion = partial(poisson_loss,
                                photons_per_pixel=self.photons_per_pixel,
                                mu_max=self.mu_max)
        else:
            warn('Unknown loss function, falling back to MSE')
            criterion = MSELoss()

        # np.inf rather than the np.infty alias, which NumPy 2.0 removed.
        best_loss = np.inf
        best_output = self.output.detach().clone()

        for i in tqdm(range(self.iterations), desc='TV',
                      disable=not self.show_pbar):
            self.optimizer.zero_grad()
            loss = (criterion(self.ray_trafo_module(self.output), y_delta)
                    + self.gamma * tv_loss(self.output))
            loss_value = loss.item()
            # ``loss`` was evaluated at the current ``self.output``, which
            # ``optimizer.step()`` updates in place. Record the iterate
            # *before* stepping so that ``best_output`` really has objective
            # value ``best_loss``. A NaN loss never compares less, so it is
            # never recorded as best.
            if loss_value < best_loss:
                best_loss = loss_value
                best_output = self.output.detach().clone()
            loss.backward()
            self.optimizer.step()

            if (self.callback_func is not None and
                    (i % self.callback_func_interval == 0 or
                     i == self.iterations - 1)):
                self.callback_func(
                    iteration=i,
                    reconstruction=best_output[0, ...].cpu().numpy(),
                    loss=best_loss)
            if self.callback is not None:
                self.callback(self.reco_space.element(
                    best_output[0, ...].cpu().numpy()))

        return self.reco_space.element(best_output[0, ...].cpu().numpy())