Source code for dival.reconstructors.networks.iterative

import torch
import torch.nn as nn


class IterativeBlock(nn.Module):
    """Stack of same-padded 2D convolutions used as one update step.

    The block maps ``n_in + n_memory`` input channels to
    ``n_out + n_memory`` output channels. Every convolution except the last
    is followed by an optional batch normalization and an activation
    (PReLU or LeakyReLU); the last convolution has no activation.
    """

    def __init__(self, n_in=3, n_out=1, n_memory=5, n_layer=3, internal_ch=32,
                 kernel_size=3, batch_norm=True, prelu=False, lrelu_coeff=0.2):
        """
        Parameters
        ----------
        n_in : int, optional
            Number of non-memory input channels.
        n_out : int, optional
            Number of non-memory output channels.
        n_memory : int, optional
            Number of memory channels, present in both input and output.
        n_layer : int, optional
            Total number of convolutions. With ``n_layer <= 1`` the block
            consists of a single convolution acting directly on the input.
        internal_ch : int, optional
            Number of channels between the convolutions.
        kernel_size : int, optional
            Odd kernel size; padding is chosen so the spatial size is kept.
        batch_norm : bool, optional
            Whether to normalize the input and the output of every hidden
            convolution with ``nn.BatchNorm2d``.
        prelu : bool, optional
            Use ``nn.PReLU`` (initialized with slope ``0.0``) instead of
            ``nn.LeakyReLU`` as activation.
        lrelu_coeff : float, optional
            Negative slope of the LeakyReLU activations.
        """
        super(IterativeBlock, self).__init__()
        assert kernel_size % 2 == 1, 'kernel_size must be odd'
        padding = (kernel_size - 1) // 2

        # Channel count fed into the next convolution. Tracking it
        # explicitly keeps the final convolution consistent with its real
        # input even when there are no hidden layers (n_layer <= 1).
        in_ch = n_in + n_memory

        modules = []
        if batch_norm:
            modules.append(nn.BatchNorm2d(in_ch))
        for _ in range(n_layer - 1):
            modules.append(nn.Conv2d(in_ch, internal_ch,
                                     kernel_size=kernel_size,
                                     padding=padding))
            if batch_norm:
                modules.append(nn.BatchNorm2d(internal_ch))
            if prelu:
                modules.append(nn.PReLU(internal_ch, init=0.0))
            else:
                modules.append(nn.LeakyReLU(lrelu_coeff, inplace=True))
            in_ch = internal_ch
        modules.append(nn.Conv2d(in_ch, n_out + n_memory,
                                 kernel_size=kernel_size, padding=padding))
        self.block = nn.Sequential(*modules)
        # Not used by forward(); kept so the attribute remains available to
        # existing code that may access it.
        self.relu = nn.LeakyReLU(lrelu_coeff, inplace=True)

    def forward(self, x):
        """Apply the convolution stack to ``x`` and return the result."""
        return self.block(x)
class IterativeNet(nn.Module):
    """Unrolled iterative network with one :class:`IterativeBlock` per step.

    In every iteration the current iterate ``x``, the term
    ``op_adj(op(x) - y)``, the term ``op_reg(x)`` and a memory tensor are
    concatenated along the channel axis and passed to the block of that
    iteration. The first output channel is added to ``x``; the remaining
    channels, passed through a ReLU, become the new memory.
    """

    def __init__(self, n_iter, n_memory, op, op_adj, op_init, op_reg,
                 use_sigmoid=False, n_layer=4, internal_ch=32, kernel_size=3,
                 batch_norm=True, prelu=False, lrelu_coeff=0.2):
        """
        Parameters
        ----------
        n_iter : int
            Number of unrolled iterations (one block is created for each).
        n_memory : int
            Number of memory channels carried between iterations.
        op : callable
            Applied to the iterate ``x``. If `op_init` is ``None``, it must
            additionally provide ``op.operator.domain.shape``, which is used
            as the spatial shape of the zero initial iterate.
        op_adj : callable
            Applied to the residual ``op(x) - y``.
        op_init : callable or None
            Computes the initial iterate from ``y``. If ``None``, the
            iteration starts from zeros.
        op_reg : callable
            Applied to the iterate ``x``; its output is fed to the block as
            one input channel.
        use_sigmoid : bool, optional
            Whether to apply a sigmoid to the final iterate.
        n_layer, internal_ch, kernel_size, batch_norm, prelu, lrelu_coeff
            Forwarded to every :class:`IterativeBlock`.
        """
        super(IterativeNet, self).__init__()
        self.n_iter = n_iter
        self.n_memory = n_memory
        self.op = op
        self.op_adj = op_adj
        self.op_init = op_init
        self.op_reg = op_reg
        self.use_sigmoid = use_sigmoid

        self.blocks = nn.ModuleList()
        for _ in range(n_iter):
            # n_in=3: iterate, op_adj(op(x) - y) and op_reg(x)
            self.blocks.append(IterativeBlock(
                n_in=3, n_out=1, n_memory=self.n_memory, n_layer=n_layer,
                internal_ch=internal_ch, kernel_size=kernel_size,
                batch_norm=batch_norm, prelu=prelu, lrelu_coeff=lrelu_coeff))

    def forward(self, y, it=-1):
        """Run the unrolled iterations on the input ``y``.

        Parameters
        ----------
        y : torch.Tensor
            Input tensor; its first dimension is used as batch size.
        it : int, optional
            Maximum number of iterations to run. ``-1`` (default) runs all
            ``n_iter`` iterations; any other value is capped at ``n_iter``.

        Returns
        -------
        x : torch.Tensor
            The final iterate, passed through a sigmoid if `use_sigmoid`
            is set.
        """
        if self.op_init is not None:
            x = self.op_init(y)
        else:
            x = torch.zeros(y.shape[0], 1, *self.op.operator.domain.shape,
                            device=y.device)
        # memory channels start at zero
        s = torch.zeros(x.shape[0], self.n_memory, *x.shape[2:],
                        device=y.device)
        n_iter = self.n_iter if it == -1 else min(self.n_iter, it)
        for i in range(n_iter):
            grad_x = self.op_adj(self.op(x) - y)
            grad_reg = self.op_reg(x)
            update = self.blocks[i](torch.cat([x, grad_x, grad_reg, s], dim=1))
            # Out-of-place on purpose: `x` may be `y` itself (or a view of
            # it) when op_init passes its input through, and `op`/`op_reg`
            # may have saved `x` for the backward pass. An in-place `+=`
            # would overwrite the caller's data or break autograd.
            x = x + update[:, :1, ...]
            s = torch.relu(update[:, 1:, ...])
        if self.use_sigmoid:
            x = torch.sigmoid(x)
        return x
class PrimalDualNet(nn.Module):
    """Unrolled network alternating dual and primal update blocks.

    Two multi-channel states are kept: a primal state with ``n_primal``
    channels and a dual state with ``n_dual`` channels. Each iteration
    first updates the dual state from ``op`` applied to primal channel 1
    and the input ``y``, then updates the primal state from ``op_adj``
    applied to dual channel 0. Primal channel 0 is the output.
    """

    def __init__(self, n_iter, op, op_adj, op_init=None, n_primal=5, n_dual=5,
                 use_sigmoid=False, n_layer=4, internal_ch=32, kernel_size=3,
                 batch_norm=True, prelu=False, lrelu_coeff=0.2):
        """
        Parameters
        ----------
        n_iter : int
            Number of unrolled iterations; one dual and one primal block
            are created per iteration.
        op : callable
            Applied to one primal channel. Must provide
            ``op.operator.domain.shape`` (spatial shape of the primal
            state).
        op_adj : callable
            Applied to one dual channel. Must provide
            ``op_adj.operator.domain.shape`` (spatial shape of the dual
            state).
        op_init : callable, optional
            If given, its output for ``y`` initializes the primal state
            (broadcast over the channel dimension); otherwise the primal
            state starts at zero.
        n_primal, n_dual : int, optional
            Number of channels of the primal and dual state.
        use_sigmoid : bool, optional
            Whether to apply a sigmoid to the output.
        n_layer, internal_ch, kernel_size, batch_norm, prelu, lrelu_coeff
            Forwarded to every :class:`IterativeBlock`.
        """
        super(PrimalDualNet, self).__init__()
        self.n_iter = n_iter
        self.op = op
        self.op_adj = op_adj
        self.op_init = op_init
        self.n_primal = n_primal
        self.n_dual = n_dual
        self.use_sigmoid = use_sigmoid

        self.primal_blocks = nn.ModuleList()
        self.dual_blocks = nn.ModuleList()

        # settings shared by all blocks
        block_kwargs = dict(
            n_out=1, n_layer=n_layer, internal_ch=internal_ch,
            kernel_size=kernel_size, batch_norm=batch_norm, prelu=prelu,
            lrelu_coeff=lrelu_coeff)
        for _ in range(n_iter):
            # dual block input: dual state, op(primal channel) and y
            dual_block = IterativeBlock(
                n_in=3, n_memory=self.n_dual - 1, **block_kwargs)
            self.dual_blocks.append(dual_block)
            # primal block input: primal state and op_adj(dual channel)
            primal_block = IterativeBlock(
                n_in=2, n_memory=self.n_primal - 1, **block_kwargs)
            self.primal_blocks.append(primal_block)

    def forward(self, y):
        """Run all ``n_iter`` iterations on ``y``.

        Returns primal channel 0 (keeping the channel dimension), passed
        through a sigmoid if `use_sigmoid` is set.
        """
        batch_size = y.shape[0]
        primal_shape = self.op.operator.domain.shape
        dual_shape = self.op_adj.operator.domain.shape

        primal = torch.zeros(batch_size, self.n_primal, *primal_shape,
                             device=y.device)
        if self.op_init is not None:
            # the initial value is broadcast over the channel dimension
            primal[:] = self.op_init(y)
        dual = torch.zeros(batch_size, self.n_dual, *dual_shape,
                           device=y.device)

        for k in range(self.n_iter):
            # dual step
            forward_proj = self.op(primal[:, 1:2, ...])
            dual_in = torch.cat([dual, forward_proj, y], dim=1)
            dual = dual + self.dual_blocks[k](dual_in)

            # primal step; applying op_adj directly is only valid for a
            # linear op (a non-linear op would need the adjoint of its
            # derivative at the current primal channel instead)
            back_proj = self.op_adj(dual[:, 0:1, ...])
            primal_in = torch.cat([primal, back_proj], dim=1)
            primal = primal + self.primal_blocks[k](primal_in)

        result = primal[:, 0:1, ...]
        if self.use_sigmoid:
            result = torch.sigmoid(result)
        return result