| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- #
- # The implementation here is modified based on
- # Jongho Choi(sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft )
- # and publicly available at
- # https://github.com/sweetcocoa/DeepComplexUNetPyTorch
- import torch
- import torch.nn as nn
- from . import complex_nn
- from .se_module_complex import SELayer
class Encoder(nn.Module):
    """Down-sampling block: convolution -> batch norm -> LeakyReLU.

    Args:
        in_channels: Number of channels fed to the convolution.
        out_channels: Number of channels produced by the convolution and
            normalised by the batch-norm layer.
        kernel_size: Kernel size, either a per-dimension sequence such as
            ``(5, 2)`` or a single int meaning a square kernel.
        stride: Stride forwarded unchanged to the convolution.
        padding: Explicit padding forwarded to the convolution.  When
            ``None``, ``(k - 1) // 2`` is used for every kernel dimension
            ('SAME'-style padding).
        complex: If true, use the ``complex_nn`` convolution / batch-norm
            layers instead of the ``torch.nn`` ones.
        padding_mode: Padding mode forwarded to the convolution.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding=None,
                 complex=False,
                 padding_mode='zeros'):
        super().__init__()
        if padding is None:
            # An int kernel size denotes a square kernel; expand it so the
            # per-dimension 'SAME' padding below works for both spellings.
            if isinstance(kernel_size, int):
                kernel_dims = (kernel_size, kernel_size)
            else:
                kernel_dims = kernel_size
            padding = [(k - 1) // 2 for k in kernel_dims]  # 'SAME' padding

        if complex:
            conv = complex_nn.ComplexConv2d
            bn = complex_nn.ComplexBatchNorm2d
        else:
            conv = nn.Conv2d
            bn = nn.BatchNorm2d

        # Attribute names (conv / bn / relu) determine the state_dict keys,
        # so they are kept as-is.
        self.conv = conv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            padding_mode=padding_mode)
        self.bn = bn(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm and LeakyReLU to ``x`` in that order."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
class Decoder(nn.Module):
    """Up-sampling block: transposed convolution -> batch norm -> LeakyReLU.

    Args:
        in_channels: Number of channels fed to the transposed convolution.
        out_channels: Number of channels it produces; also the width of the
            batch-norm layer.
        kernel_size: Kernel size forwarded to the transposed convolution.
        stride: Stride forwarded to the transposed convolution.
        padding: Padding forwarded to the transposed convolution.
        complex: If true, build the block from the ``complex_nn`` layers
            rather than the ``torch.nn`` ones.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding=(0, 0),
                 complex=False):
        super().__init__()

        # Only the selected branch is evaluated, so `complex_nn` is touched
        # solely when the complex variant is requested.
        tconv_cls = (
            complex_nn.ComplexConvTranspose2d if complex else nn.ConvTranspose2d)
        norm_cls = complex_nn.ComplexBatchNorm2d if complex else nn.BatchNorm2d

        self.transconv = tconv_cls(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)
        self.bn = norm_cls(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        """Run ``x`` through the transposed conv, batch norm and activation."""
        return self.relu(self.bn(self.transconv(x)))
class UNet(nn.Module):
    """U-shaped encoder/decoder with FSMN blocks and SE-gated skip paths.

    The network stacks ``model_depth // 2`` :class:`Encoder` stages, one
    bottleneck FSMN block, and the same number of :class:`Decoder` stages.
    Skip connections carry the squeeze-excitation (SE) output of each
    encoder stage to the decoder, where it is concatenated on the channel
    axis.  A final ``linear`` convolution produces the output.

    Every sub-module is registered twice: under an indexed name through
    ``add_module`` (``encoder0``, ``fsmn_dec3``, ...) and again inside an
    ``nn.ModuleList`` (``encoders``, ``fsmn_dec``, ...).  Both naming schemes
    therefore show up in ``state_dict()``; presumably existing checkpoints
    rely on them, so neither registration should be dropped.

    Args:
        input_channels: Channel count of the network input (first entry of
            the encoder channel table).
        complex: If true, encoders, decoders and the final convolution use
            the ``complex_nn`` layers.  The FSMN blocks always come from
            ``complex_nn`` regardless of this flag.
        model_complexity: Base channel width.  Only the ``model_depth == 20``
            configuration reads it.
        model_depth: Selects the layer configuration; must be 10, 14 or 20.
        padding_mode: Padding mode forwarded to every encoder convolution.

    Raises:
        ValueError: If ``model_depth`` is not one of the supported values.
    """

    def __init__(self,
                 input_channels=1,
                 complex=False,
                 model_complexity=45,
                 model_depth=20,
                 padding_mode='zeros'):
        super().__init__()

        if complex:
            # 1.414 ~ sqrt(2); presumably shrinks the width to offset the
            # extra weights of complex layers -- TODO confirm.
            model_complexity = int(model_complexity // 1.414)

        self.set_size(
            model_complexity=model_complexity,
            input_channels=input_channels,
            model_depth=model_depth)
        self.model_length = model_depth // 2

        # Bottleneck block applied between the encoder and decoder stacks.
        self.fsmn = complex_nn.ComplexUniDeepFsmn(128, 128, 128)

        # --- encoder side -------------------------------------------------
        # Construction / registration order is kept exactly as before so the
        # state_dict layout and parameter-initialisation order do not change.
        encoders = []
        se_layers_enc = []
        fsmn_enc = []
        for i in range(self.model_length):
            # Note: forward() never calls the stage-0 FSMN block; it is still
            # built so the set of registered modules stays the same.
            fsmn_layer = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128)
            self.add_module('fsmn_enc{}'.format(i), fsmn_layer)
            fsmn_enc.append(fsmn_layer)

            encoder = Encoder(
                self.enc_channels[i],
                self.enc_channels[i + 1],
                kernel_size=self.enc_kernel_sizes[i],
                stride=self.enc_strides[i],
                padding=self.enc_paddings[i],
                complex=complex,
                padding_mode=padding_mode)
            self.add_module('encoder{}'.format(i), encoder)
            encoders.append(encoder)

            se_layer = SELayer(self.enc_channels[i + 1], 8)
            self.add_module('se_layer_enc{}'.format(i), se_layer)
            se_layers_enc.append(se_layer)

        # --- decoder side -------------------------------------------------
        decoders = []
        fsmn_dec = []
        se_layers_dec = []
        for i in range(self.model_length):
            # Note: forward() skips the FSMN block of the last decoder stage.
            fsmn_layer = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128)
            self.add_module('fsmn_dec{}'.format(i), fsmn_layer)
            fsmn_dec.append(fsmn_layer)

            # Input width is doubled because forward() concatenates the
            # decoder activation with a skip tensor on the channel axis.
            decoder = Decoder(
                self.dec_channels[i] * 2,
                self.dec_channels[i + 1],
                kernel_size=self.dec_kernel_sizes[i],
                stride=self.dec_strides[i],
                padding=self.dec_paddings[i],
                complex=complex)
            self.add_module('decoder{}'.format(i), decoder)
            decoders.append(decoder)

            # No SE layer is created for the last decoder stage.
            if i < self.model_length - 1:
                se_layer = SELayer(self.dec_channels[i + 1], 8)
                self.add_module('se_layer_dec{}'.format(i), se_layer)
                se_layers_dec.append(se_layer)

        # Final projection: dec_channels[-1] input channels, 1 output
        # channel, kernel size 1.
        conv = complex_nn.ComplexConv2d if complex else nn.Conv2d
        linear = conv(self.dec_channels[-1], 1, 1)
        self.add_module('linear', linear)

        self.complex = complex
        self.padding_mode = padding_mode

        # Second registration, as indexable ModuleLists used by forward().
        self.decoders = nn.ModuleList(decoders)
        self.encoders = nn.ModuleList(encoders)
        self.se_layers_enc = nn.ModuleList(se_layers_enc)
        self.se_layers_dec = nn.ModuleList(se_layers_dec)
        self.fsmn_enc = nn.ModuleList(fsmn_enc)
        self.fsmn_dec = nn.ModuleList(fsmn_dec)

    def forward(self, inputs):
        """Run the encoder stack, bottleneck FSMN and decoder stack.

        Args:
            inputs: Network input tensor.

        Returns:
            Output of the final ``linear`` layer.  The original author's note
            gives ``[12, 1, 513, 64, 2]`` as an example shape.
        """
        x = inputs

        # go down
        # skips[0] is the raw input; skips[i + 1] is the SE output of encoder
        # stage i.  Only the SE outputs feed the decoder, so the un-gated
        # encoder inputs are no longer collected.
        skips = [x]
        for i, encoder in enumerate(self.encoders):
            if i > 0:
                x = self.fsmn_enc[i](x)
            x = encoder(x)
            skips.append(self.se_layers_enc[i](x))

        x = self.fsmn(x)

        # go up
        p = x
        last = self.model_length - 1
        for i, decoder in enumerate(self.decoders):
            p = decoder(p)
            if i < last:
                p = self.fsmn_dec[i](p)
            if i == last:
                # The last stage feeds the output layer directly.
                break
            if i < last - 1:
                p = self.se_layers_dec[i](p)
            # Concatenate the matching skip tensor on the channel axis.
            p = torch.cat([p, skips[last - i]], dim=1)

        # cmp_spec: [12, 1, 513, 64, 2]
        cmp_spec = self.linear(p)
        return cmp_spec

    def set_size(self, model_complexity, model_depth=20, input_channels=1):
        """Populate the per-stage layer tables for the chosen depth.

        Sets ``enc_channels``, ``enc_kernel_sizes``, ``enc_strides``,
        ``enc_paddings`` and the matching ``dec_*`` attributes on ``self``.

        Args:
            model_complexity: Base channel width; only used when
                ``model_depth == 20``.
            model_depth: One of 10, 14 or 20.
            input_channels: First entry of ``enc_channels``.

        Raises:
            ValueError: If ``model_depth`` is not 10, 14 or 20.
        """
        if model_depth == 14:
            self.enc_channels = [input_channels] + [128] * 7
            self.enc_kernel_sizes = [(5, 2)] * 6 + [(2, 2)]
            self.enc_strides = [(2, 1)] * 7
            self.enc_paddings = [(0, 1)] * 7

            self.dec_channels = [64] + [128] * 6 + [1]
            self.dec_kernel_sizes = [(2, 2), (5, 2), (5, 2), (5, 2), (6, 2),
                                     (5, 2), (5, 2)]
            self.dec_strides = [(2, 1)] * 7
            self.dec_paddings = [(0, 1)] * 7

        elif model_depth == 10:
            self.enc_channels = [input_channels, 16, 32, 64, 128, 256]
            self.enc_kernel_sizes = [(3, 3)] * 5
            self.enc_strides = [(2, 1)] * 5
            self.enc_paddings = [(0, 1)] * 5

            self.dec_channels = [128, 128, 64, 32, 16, 1]
            self.dec_kernel_sizes = [(3, 3), (3, 3), (3, 3), (4, 3), (3, 3)]
            self.dec_strides = [(2, 1)] * 5
            self.dec_paddings = [(0, 1)] * 5

        elif model_depth == 20:
            wide = model_complexity * 2
            self.enc_channels = ([input_channels, model_complexity,
                                  model_complexity] + [wide] * 7 + [128])
            self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3),
                                     (5, 3), (5, 3), (5, 3), (5, 3), (5, 3)]
            self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2), (2, 1),
                                (2, 2), (2, 1), (2, 2), (2, 1)]
            # None lets Encoder derive 'SAME'-style padding from the kernel.
            self.enc_paddings = [(3, 0), (0, 3)] + [None] * 8

            # NOTE(review): __init__ builds decoder i with
            # dec_channels[i] * 2 input channels, so the leading 0 gives
            # decoder 0 zero input channels -- looks inconsistent with the
            # other depths; confirm before using this configuration.
            self.dec_channels = [0] + [wide] * 11
            self.dec_kernel_sizes = [(4, 3), (4, 2), (4, 3), (4, 2), (4, 3),
                                     (4, 2), (6, 3), (7, 4), (1, 7), (7, 1)]
            self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1), (2, 2),
                                (2, 1), (2, 2), (1, 1), (1, 1)]
            self.dec_paddings = [(1, 1), (1, 0), (1, 1), (1, 0), (1, 1),
                                 (1, 0), (2, 1), (2, 1), (0, 3), (3, 0)]
        else:
            raise ValueError('Unknown model depth : {}'.format(model_depth))
|