import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import numpy as np
from dival.reconstructors.networks.unet import get_unet_model
def get_iradonmap_model(ray_trafo, fully_learned, scales=5, skip=4,
                        channels=(32, 32, 64, 64, 128, 128), use_sigmoid=True,
                        use_norm=True, coord_mat=None):
    """Build an :class:`IRadonMap` whose post-processing stage is a U-Net.

    The U-Net is created by ``get_unet_model`` with one input and one output
    channel and is handed to :class:`IRadonMap` as its ``post_process``
    module.

    Parameters
    ----------
    ray_trafo
        Operator object passed on unchanged to :class:`IRadonMap`.
    fully_learned : bool
        Passed on unchanged to :class:`IRadonMap`.
    scales, skip, channels, use_sigmoid, use_norm
        Forwarded unchanged to ``get_unet_model`` (see that function for
        their meaning).
    coord_mat : optional
        Passed on unchanged to :class:`IRadonMap`. Defaults to ``None``.

    Returns
    -------
    IRadonMap
        The assembled model.
    """
    post_process = get_unet_model(in_ch=1, out_ch=1, scales=scales, skip=skip,
                                  channels=channels, use_sigmoid=use_sigmoid,
                                  use_norm=use_norm)
    return IRadonMap(ray_trafo=ray_trafo, post_process=post_process,
                     fully_learned=fully_learned, coord_mat=coord_mat)
class IRadonMap(nn.Module):
    """Three-stage network: linear layer, backprojection, post-processing.

    ``forward`` applies, in this order,

    1. a bias-free ``nn.Linear`` acting on the last axis of the input
       (size ``ray_trafo.range.shape[-1]``, stored as ``num_detectors``),
    2. a :class:`LearnedBackprojection`,
    3. the user supplied ``post_process`` module.
    """

    def __init__(self, ray_trafo, post_process, fully_learned, coord_mat=None):
        """
        Parameters
        ----------
        ray_trafo
            Object whose ``range.shape[-1]`` gives the number of detector
            pixels; it is also passed on to :class:`LearnedBackprojection`.
        post_process : callable (typically an ``nn.Module``)
            Applied to the output of the backprojection.
        fully_learned : bool
            Forwarded to :class:`LearnedBackprojection` as ``use_weights``.
        coord_mat : optional
            Forwarded to :class:`LearnedBackprojection`. Defaults to
            ``None``.
        """
        super(IRadonMap, self).__init__()
        self.num_detectors = ray_trafo.range.shape[-1]
        # Square, bias-free map along the detector axis.
        self.linear_layer = nn.Linear(in_features=self.num_detectors,
                                      out_features=self.num_detectors,
                                      bias=False)
        self.adj_ray = LearnedBackprojection(ray_trafo=ray_trafo,
                                             use_weights=fully_learned,
                                             coord_mat=coord_mat)
        self.post_process = post_process

    def forward(self, x):
        """Run ``x`` through linear layer, backprojection and post-processing.

        Parameters
        ----------
        x : torch.Tensor
            Input whose last axis has length ``num_detectors``.
            NOTE(review): presumably a batch of sinograms laid out as
            (batch, channel, angle, detector) -- confirm against callers.

        Returns
        -------
        Whatever ``post_process`` returns for the backprojected input.
        """
        x = self.linear_layer(x)
        x = self.adj_ray(x)
        x = self.post_process(x)
        return x
class LearnedBackprojection(nn.Module):
    """Backprojection implemented as a gather-and-sum over a lookup table.

    For every image pixel and every angle, ``coord_mat`` stores the flat
    index of one entry of the flattened sinogram. ``forward`` gathers these
    entries, optionally multiplies them by a learnable weight per
    (pixel, angle) pair, and sums over the angles.
    """

    def __init__(self, ray_trafo, use_weights, coord_mat=None):
        """
        Parameters
        ----------
        ray_trafo
            Object providing ``domain.shape`` (first two entries are used
            as image size) and ``range.shape`` (first two entries are used
            as number of angles and number of detector pixels).
        use_weights : bool
            If true, ``weights`` is a trainable parameter of shape
            ``(1, 1, x_range, y_range, num_angles)`` initialised with ones;
            otherwise it is the constant ``1.0``.
        coord_mat : optional
            Precomputed index table. If ``None`` it is computed with
            :meth:`calc_coord_mat`.
        """
        super(LearnedBackprojection, self).__init__()
        self.x_range = ray_trafo.domain.shape[0]
        self.y_range = ray_trafo.domain.shape[1]
        self.num_angles = ray_trafo.range.shape[0]
        self.num_detectors = ray_trafo.range.shape[1]
        if use_weights:
            # One weight per (pixel, angle); the leading singleton axes
            # broadcast over the first two axes of the gathered values.
            self.weights = Parameter(
                torch.ones(1, 1, self.x_range, self.y_range,
                           self.num_angles))
        else:
            self.weights = 1.0
        self.coord_mat = (self.calc_coord_mat() if coord_mat is None
                          else coord_mat)

    def forward(self, sinogram):
        """Gather sinogram values per pixel and sum them over the angles.

        Parameters
        ----------
        sinogram : torch.Tensor
            4-dimensional tensor. Its last two axes are flattened into one
            before indexing with ``coord_mat``.
            NOTE(review): presumably (batch, channel, angle, detector) --
            confirm against callers.

        Returns
        -------
        torch.Tensor
            The first two axes of ``sinogram`` followed by the axes of
            ``coord_mat`` without its last one; with the table produced by
            :meth:`calc_coord_mat` that is
            ``(size0, size1, x_range, y_range)``.
        """
        flat = sinogram.reshape((sinogram.size(0), sinogram.size(1),
                                 sinogram.size(2) * sinogram.size(3)))
        gathered = flat[:, :, self.coord_mat]
        return torch.sum(gathered * self.weights, dim=-1)

    def calc_coord_mat(self):
        """Compute the flat sinogram index for every (pixel, angle) pair.

        For pixel ``(i, j)`` and angle index ``theta`` (angle
        ``theta * pi / num_angles``) the signed detector coordinate is
        ``round((i - x_range // 2) * cos(angle)
        + (j - y_range // 2) * sin(angle))``.
        All coordinates are then shifted by the absolute value of the global
        minimum and offset by ``theta * num_detectors``, which makes them
        indices into a sinogram flattened over its last two axes.

        NOTE(review): the shifted coordinates are not checked against
        ``num_detectors``; presumably the detector is expected to be wide
        enough to cover them -- verify for new geometries.

        Returns
        -------
        numpy.ndarray
            ``int32`` array of shape ``(x_range, y_range, num_angles)``.
        """
        # Pixel grids centred on the image midpoint. They do not depend on
        # the angle, so they are built once outside the loop.
        x_centered = np.arange(self.x_range)[:, None] - int(self.x_range / 2)
        y_centered = np.arange(self.y_range)[None, :] - int(self.y_range / 2)
        angle_step = np.pi / self.num_angles
        coord_matrix = np.empty((self.x_range, self.y_range,
                                 self.num_angles), dtype=np.int32)
        for theta in range(self.num_angles):
            angle = angle_step * theta
            # Rounded float result is cast to int32 on assignment.
            coord_matrix[:, :, theta] = np.around(
                x_centered * np.cos(angle) + y_centered * np.sin(angle))
        # Make all detector coordinates non-negative.
        coord_matrix = coord_matrix + np.abs(np.amin(coord_matrix))
        # Move each angle into its own row of the flattened sinogram
        # (broadcast over the last axis instead of a Python loop).
        coord_matrix += (np.arange(self.num_angles, dtype=np.int32)
                         * self.num_detectors)
        return coord_matrix