# Source code for dival.reconstructors.networks.unet

import torch
import torch.nn as nn
import numpy as np


def get_unet_model(in_ch=1, out_ch=1, scales=5, skip=4,
                   channels=(32, 32, 64, 64, 128, 128), use_sigmoid=True,
                   use_norm=True):
    """Build a :class:`UNet` with ``scales`` scales and uniform skip width.

    Parameters
    ----------
    in_ch : int, optional
        Number of input channels.
    out_ch : int, optional
        Number of output channels.
    scales : int, optional
        Number of scales; must be between 1 and 6 (inclusive).
    skip : int, optional
        Number of channels used for every skip connection
        (``skip_channels = [skip] * scales``).
    channels : sequence of int, optional
        Feature channels per scale; only the first ``scales`` entries are
        used, so it must contain at least ``scales`` entries.
    use_sigmoid : bool, optional
        Whether the network output is passed through a sigmoid.
    use_norm : bool, optional
        Whether the blocks use ``BatchNorm2d`` layers.

    Returns
    -------
    UNet
        The constructed network.

    Raises
    ------
    ValueError
        If ``scales`` is outside ``1..6`` or ``channels`` is shorter than
        ``scales``.
    """
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let invalid configurations through silently.
    if not 1 <= scales <= 6:
        raise ValueError(
            'scales must be between 1 and 6, got {}'.format(scales))
    if len(channels) < scales:
        raise ValueError(
            'channels must have at least `scales` ({}) entries, got {}'
            .format(scales, len(channels)))
    skip_channels = [skip] * scales
    return UNet(in_ch=in_ch, out_ch=out_ch, channels=channels[:scales],
                skip_channels=skip_channels, use_sigmoid=use_sigmoid,
                use_norm=use_norm)
class UNet(nn.Module):
    """U-Net style encoder/decoder assembled from the blocks in this module.

    The encoder is an :class:`InBlock` followed by ``len(channels) - 1``
    :class:`DownBlock` s. The decoder is ``len(channels) - 1``
    :class:`UpBlock` s, each merging the current features with the encoder
    output of the matching scale. An :class:`OutBlock` follows the decoder.

    Parameters
    ----------
    in_ch : int
        Number of input channels.
    out_ch : int
        Number of output channels.
    channels : sequence of int
        Feature channels per scale; ``len(channels)`` is the number of scales.
    skip_channels : sequence of int
        Skip-connection channels; must have the same length as ``channels``.
        The skip connection from scale ``k`` uses ``skip_channels[k + 1]``,
        so ``skip_channels[0]`` is never used.
    use_sigmoid : bool, optional
        Whether :meth:`forward` applies a sigmoid to the output.
    use_norm : bool, optional
        Whether the blocks use ``BatchNorm2d`` layers.

    Raises
    ------
    ValueError
        If ``channels`` and ``skip_channels`` differ in length, or if
        ``channels`` is empty.
    """

    def __init__(self, in_ch, out_ch, channels, skip_channels,
                 use_sigmoid=True, use_norm=True):
        super(UNet, self).__init__()
        # Raise instead of ``assert`` so the check survives ``python -O``.
        if len(channels) != len(skip_channels):
            raise ValueError(
                'channels and skip_channels must have the same length, '
                'got {} and {}'.format(len(channels), len(skip_channels)))
        if len(channels) < 1:
            raise ValueError('channels must contain at least one entry')
        self.scales = len(channels)
        self.use_sigmoid = use_sigmoid
        # Submodules are registered in the original order (down, up, inc,
        # outc) so parameter / state_dict ordering is unchanged.
        self.down = nn.ModuleList()
        self.up = nn.ModuleList()
        self.inc = InBlock(in_ch, channels[0], use_norm=use_norm)
        for i in range(1, self.scales):
            self.down.append(DownBlock(in_ch=channels[i - 1],
                                       out_ch=channels[i],
                                       use_norm=use_norm))
        for i in range(1, self.scales):
            self.up.append(UpBlock(in_ch=channels[-i],
                                   out_ch=channels[-i - 1],
                                   skip_ch=skip_channels[-i],
                                   use_norm=use_norm))
        self.outc = OutBlock(in_ch=channels[0], out_ch=out_ch)

    def forward(self, x0):
        """Run the network on ``x0``.

        ``x0`` is presumably a 4-D ``(N, C, H, W)`` tensor (the blocks use
        ``Conv2d``/``BatchNorm2d`` and ``Concat`` indexes dims 2 and 3).
        """
        # Encoder: keep every scale's output for the skip connections.
        xs = [self.inc(x0)]
        for down in self.down:
            xs.append(down(xs[-1]))
        # Decoder: walk back up, merging with the encoder outputs from the
        # second-deepest scale to the shallowest.
        x = xs[-1]
        for i, up in enumerate(self.up):
            x = up(x, xs[-2 - i])
        out = self.outc(x)
        return torch.sigmoid(out) if self.use_sigmoid else out
class DownBlock(nn.Module):
    """Downsampling stage: a stride-2 convolution followed by a stride-1 one.

    Each convolution is followed by an optional ``BatchNorm2d`` (when
    ``use_norm``) and a ``LeakyReLU(0.2)``. Both convolutions produce
    ``out_ch`` channels and use padding ``(kernel_size - 1) / 2``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, use_norm=True):
        super(DownBlock, self).__init__()
        padding = int((kernel_size - 1) / 2)
        layers = []
        # First conv halves the resolution (stride 2); second keeps it.
        for conv_in, stride in ((in_ch, 2), (out_ch, 1)):
            layers.append(nn.Conv2d(conv_in, out_ch, kernel_size,
                                    stride=stride, padding=padding))
            if use_norm:
                layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution stack to ``x``."""
        return self.conv(x)
class InBlock(nn.Module):
    """Input stage: one stride-1 convolution, optional BatchNorm, LeakyReLU.

    The convolution maps ``in_ch`` to ``out_ch`` channels with padding
    ``(kernel_size - 1) / 2``; ``BatchNorm2d`` is inserted when ``use_norm``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, use_norm=True):
        super(InBlock, self).__init__()
        padding = int((kernel_size - 1) / 2)
        layers = [nn.Conv2d(in_ch, out_ch, kernel_size, stride=1,
                            padding=padding)]
        if use_norm:
            layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution stack to ``x``."""
        return self.conv(x)
class UpBlock(nn.Module):
    """Upsampling stage that merges decoder features with a skip input.

    ``forward(x1, x2)`` bilinearly upsamples ``x1`` by 2 and passes ``x2``
    through a 1x1 convolution (``out_ch -> skip_ch`` channels). It then
    concatenates both along the channel axis via :class:`Concat` and applies
    two convolutions (``in_ch + skip_ch -> out_ch -> out_ch``).

    If ``skip_ch == 0``, the skip branch is still built with one channel but
    its output is multiplied by zero in ``forward``.
    """

    def __init__(self, in_ch, out_ch, skip_ch=4, kernel_size=3,
                 use_norm=True):
        super(UpBlock, self).__init__()
        padding = int((kernel_size - 1) / 2)
        self.skip = skip_ch > 0
        # A disabled skip keeps a 1-channel placeholder branch so the
        # layer shapes stay valid; its contribution is zeroed in forward().
        if skip_ch == 0:
            skip_ch = 1
        merged_ch = in_ch + skip_ch

        # Main path: [BN] -> (Conv -> [BN] -> LeakyReLU) x 2
        body = [nn.BatchNorm2d(merged_ch)] if use_norm else []
        for conv_in in (merged_ch, out_ch):
            body.append(nn.Conv2d(conv_in, out_ch, kernel_size, stride=1,
                                  padding=padding))
            if use_norm:
                body.append(nn.BatchNorm2d(out_ch))
            body.append(nn.LeakyReLU(0.2, inplace=True))
        self.conv = nn.Sequential(*body)

        # Skip path: 1x1 Conv -> [BN] -> LeakyReLU
        skip_layers = [nn.Conv2d(out_ch, skip_ch, kernel_size=1, stride=1)]
        if use_norm:
            skip_layers.append(nn.BatchNorm2d(skip_ch))
        skip_layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.skip_conv = nn.Sequential(*skip_layers)

        self.up = nn.Upsample(scale_factor=2, mode='bilinear',
                              align_corners=True)
        self.concat = Concat()

    def forward(self, x1, x2):
        """Upsample ``x1``, merge it with the transformed ``x2``, convolve."""
        upsampled = self.up(x1)
        skip = self.skip_conv(x2)
        if not self.skip:
            skip = skip * 0
        return self.conv(self.concat(upsampled, skip))
class Concat(nn.Module):
    """Concatenate tensors along dim 1, center-cropping to a common size.

    When the inputs disagree in dims 2 or 3, each one is cropped to the
    smallest size in that dim. The crop offset is ``(size - min) // 2``.
    """

    def __init__(self):
        super(Concat, self).__init__()

    def forward(self, *inputs):
        """Concatenate ``inputs`` along dim 1 after cropping dims 2 and 3."""
        heights = [t.shape[2] for t in inputs]
        widths = [t.shape[3] for t in inputs]
        min_h = min(heights)
        min_w = min(widths)
        if all(h == min_h for h in heights) and all(w == min_w for w in widths):
            return torch.cat(inputs, dim=1)
        cropped = []
        for t in inputs:
            top = (t.size(2) - min_h) // 2
            left = (t.size(3) - min_w) // 2
            cropped.append(t[:, :, top:top + min_h, left:left + min_w])
        return torch.cat(cropped, dim=1)
class OutBlock(nn.Module):
    """Output stage: a single 1x1 convolution from ``in_ch`` to ``out_ch``."""

    def __init__(self, in_ch, out_ch):
        super(OutBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
                              kernel_size=1, stride=1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        return self.conv(x)

    def __len__(self):
        """Return the number of direct child modules registered on this block."""
        return len(self._modules)