"""Provides custom loss functions for PyTorch."""
import torch
import numpy as np
from dival.util.constants import MU_MAX
[docs]
def tv_loss(x):
    """
    Anisotropic total-variation (TV) loss, similar to the one in [1]_.

    Absolute forward differences are taken along the last axis and along
    the second-to-last axis. Both difference tensors are then cropped to
    their common ``(..., H-1, W-1)`` region before summation. The last row
    of last-axis differences and the last column of second-to-last-axis
    differences are therefore not included.

    Parameters
    ----------
    x : :class:`torch.Tensor`
        Tensor of which to compute the anisotropic TV w.r.t. its last two
        axes. Any leading axes are treated as batch axes and summed over.

    Returns
    -------
    loss : :class:`torch.Tensor`
        Scalar tensor with the summed absolute differences.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Total_variation_denoising
    """
    # Neighbour differences along the last axis: shape (..., H, W-1).
    grad_last = (x[..., :, 1:] - x[..., :, :-1]).abs()
    # Neighbour differences along the second-to-last axis: shape (..., H-1, W).
    grad_second_last = (x[..., 1:, :] - x[..., :-1, :]).abs()
    # Crop both to (..., H-1, W-1) so they can be added elementwise.
    cropped = grad_last[..., :-1, :] + grad_second_last[..., :, :-1]
    return cropped.sum()
[docs]
def poisson_loss(y_pred, y_true, photons_per_pixel=4096, mu_max=MU_MAX):
    """
    Loss corresponding to Poisson regression (cf. [2]_) for post-log CT data.

    Both observations are mapped back to photon counts via
    ``photons_per_pixel * exp(-mu_max * y)``. The returned value is the sum
    over all entries of ``lambda - k * log(lambda)``, where ``lambda`` are
    the counts derived from ``y_pred`` and ``k`` those derived from
    ``y_true``. This is the negative Poisson log-likelihood without the
    ``log(k!)`` term, which does not depend on ``y_pred``.

    The default parameters are based on the LoDoPaB dataset creation
    (cf. [3]_).

    :Authors:
        Sören Dittmer <sdittmer@math.uni-bremen.de>

    Parameters
    ----------
    y_pred : :class:`torch.Tensor`
        Predicted observation (post-log, normalized by `mu_max`).
        Each entry determines the mean of a Poisson distribution,
        ``photons_per_pixel * exp(-mu_max * y_pred)``. The likelihood of the
        observation ``y_true`` (converted the same way) is considered with
        respect to that mean.
    y_true : :class:`torch.Tensor`
        True observation (post-log, normalized by `mu_max`).
    photons_per_pixel : int or float, optional
        Mean number of photons per detector pixel for an unattenuated beam.
        Default: `4096`.
    mu_max : float, optional
        Normalization factor, by which `y_pred` and `y_true` have
        been divided (this function will multiply by it accordingly).
        Default: ``dival.util.constants.MU_MAX``.

    Returns
    -------
    loss : :class:`torch.Tensor`
        Scalar tensor, summed over all entries.

    References
    ----------
    .. [2] https://en.wikipedia.org/wiki/Poisson_regression
    .. [3] https://github.com/jleuschn/lodopab_tech_ref/blob/master/create_dataset.py
    """
    log_photons = np.log(photons_per_pixel)

    def to_counts(post_log):
        # Undo the normalization and the -log transform.
        return torch.exp(-post_log * mu_max) * photons_per_pixel

    observed = to_counts(y_true)
    expected = to_counts(y_pred)
    # log(expected), computed analytically instead of via torch.log so it
    # stays finite even where `expected` underflows to zero.
    log_expected = -y_pred * mu_max + log_photons
    return torch.sum(expected - observed * log_expected)