# The implementation is adopted from ECBSR,
# made publicly available under the Apache 2.0 License at
# https://github.com/xindongzhang/ECBSR/blob/main/models/ecb.py
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
class SeqConv3x3(nn.Module):
    """A 1x1 convolution followed by a 3x3 convolution, foldable into one 3x3.

    Supported ``seq_type`` values:

    * ``'conv1x1-conv3x3'``: learnable 1x1 conv (``inp_planes`` ->
      ``int(out_planes * depth_multiplier)``) followed by a learnable dense
      3x3 conv to ``out_planes``.
    * ``'conv1x1-sobelx'``, ``'conv1x1-sobely'``, ``'conv1x1-laplacian'``:
      learnable 1x1 conv (``inp_planes`` -> ``out_planes``) followed by a
      depthwise 3x3 conv whose kernel is a fixed mask multiplied by a
      learnable per-channel ``scale``, plus a learnable ``bias``.
      ``depth_multiplier`` is ignored for these types.

    Args:
        seq_type: One of the four type strings above.
        inp_planes: Number of input channels.
        out_planes: Number of output channels.
        depth_multiplier: Width factor of the intermediate feature map
            (only used by ``'conv1x1-conv3x3'``).

    Raises:
        ValueError: If ``seq_type`` is not one of the supported strings.
    """

    # Fixed 3x3 masks of the depthwise stage, keyed by sequence type.
    _EDGE_KERNELS = {
        'conv1x1-sobelx': ((1.0, 0.0, -1.0),
                           (2.0, 0.0, -2.0),
                           (1.0, 0.0, -1.0)),
        'conv1x1-sobely': ((1.0, 2.0, 1.0),
                           (0.0, 0.0, 0.0),
                           (-1.0, -2.0, -1.0)),
        'conv1x1-laplacian': ((0.0, 1.0, 0.0),
                              (1.0, -4.0, 1.0),
                              (0.0, 1.0, 0.0)),
    }

    def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier):
        super(SeqConv3x3, self).__init__()
        self.type = seq_type
        self.inp_planes = inp_planes
        self.out_planes = out_planes

        if self.type == 'conv1x1-conv3x3':
            self.mid_planes = int(out_planes * depth_multiplier)
            # The Conv2d modules are only used for their default
            # initialisation; their tensors are registered directly here.
            conv0 = torch.nn.Conv2d(
                self.inp_planes, self.mid_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias
            conv1 = torch.nn.Conv2d(
                self.mid_planes, self.out_planes, kernel_size=3)
            self.k1 = conv1.weight
            self.b1 = conv1.bias
        elif self.type in self._EDGE_KERNELS:
            conv0 = torch.nn.Conv2d(
                self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias
            # Small random init for the learnable per-channel scale & bias.
            # Registration order (k0, b0, scale, bias, mask) is kept stable
            # so state dicts and optimizer parameter order do not change.
            self.scale = nn.Parameter(
                torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3)
            self.bias = nn.Parameter(torch.randn(self.out_planes) * 1e-3)
            # Same fixed mask for every output channel: (out_planes, 1, 3, 3).
            kernel = torch.tensor(
                self._EDGE_KERNELS[self.type], dtype=torch.float32)
            mask = kernel.repeat(self.out_planes, 1, 1, 1)
            self.mask = nn.Parameter(data=mask, requires_grad=False)
        else:
            raise ValueError('the type of seqconv is not supported!')

    def forward(self, x):
        """Apply the 1x1 conv, pad its output with the 1x1 bias, then the 3x3.

        The one-pixel border is filled with ``b0`` instead of zeros; this is
        what a zero-padded input would produce after the 1x1 conv, so the
        result matches a single padded 3x3 conv built from ``rep_params``.
        """
        # conv-1x1
        y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
        # explicitly padding with bias
        y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
        b0_pad = self.b0.view(1, -1, 1, 1)
        y0[:, :, 0:1, :] = b0_pad
        y0[:, :, -1:, :] = b0_pad
        y0[:, :, :, 0:1] = b0_pad
        y0[:, :, :, -1:] = b0_pad
        # conv-3x3
        if self.type == 'conv1x1-conv3x3':
            return F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
        return F.conv2d(
            input=y0,
            weight=self.scale * self.mask,
            bias=self.bias,
            stride=1,
            groups=self.out_planes)

    def rep_params(self):
        """Fold the two stages into the weight and bias of one 3x3 conv.

        Returns:
            Tuple ``(RK, RB)``: ``RK`` has shape
            ``(out_planes, inp_planes, 3, 3)`` and ``RB`` has shape
            ``(out_planes,)``. Both take their dtype and device from the
            module parameters.
        """
        if self.type == 'conv1x1-conv3x3':
            k1, b1 = self.k1, self.b1
        else:
            # Expand the depthwise kernel into an equivalent dense kernel:
            # channel i only reads channel i, so only the diagonal is set.
            depthwise = self.scale * self.mask
            eye = torch.eye(
                self.out_planes,
                dtype=depthwise.dtype,
                device=depthwise.device).view(
                    self.out_planes, self.out_planes, 1, 1)
            k1 = eye * depthwise
            b1 = self.bias
        # re-param conv kernel
        RK = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
        # re-param conv bias: push a constant b0 patch through the 3x3 stage
        RB = self.b0.view(1, -1, 1, 1).repeat(1, 1, 3, 3)
        RB = F.conv2d(input=RB, weight=k1).view(-1, ) + b1
        return RK, RB
class ECB(nn.Module):
    """Multi-branch 3x3 block that collapses to a single 3x3 conv in eval mode.

    In training mode the output is the sum of five parallel branches: a plain
    3x3 conv, a ``conv1x1-conv3x3`` sequence and three ``conv1x1`` + fixed
    kernel sequences (sobel-x, sobel-y, laplacian), plus the input itself when
    the identity shortcut is enabled. In eval mode the branches are folded
    into one kernel/bias pair on every call and applied as a single conv.

    Args:
        inp_planes: Number of input channels.
        out_planes: Number of output channels.
        depth_multiplier: Width factor of the ``conv1x1-conv3x3`` branch.
        act_type: ``'prelu'``, ``'relu'``, ``'rrelu'``, ``'softplus'`` or
            ``'linear'`` (no activation).
        with_idt: Request an identity shortcut; it is only enabled when
            ``inp_planes == out_planes``.

    Raises:
        ValueError: If ``act_type`` is not one of the supported strings.
    """

    def __init__(self,
                 inp_planes,
                 out_planes,
                 depth_multiplier,
                 act_type='prelu',
                 with_idt=False):
        super(ECB, self).__init__()

        self.depth_multiplier = depth_multiplier
        self.inp_planes = inp_planes
        self.out_planes = out_planes
        self.act_type = act_type
        # The shortcut needs matching channel counts to be addable.
        self.with_idt = bool(with_idt
                             and self.inp_planes == self.out_planes)

        self.conv3x3 = torch.nn.Conv2d(
            self.inp_planes, self.out_planes, kernel_size=3, padding=1)
        self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.inp_planes,
                                      self.out_planes, self.depth_multiplier)
        # depth_multiplier is unused by the fixed-kernel branches.
        self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.inp_planes,
                                      self.out_planes, -1)
        self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.inp_planes,
                                      self.out_planes, -1)
        self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.inp_planes,
                                      self.out_planes, -1)

        if self.act_type == 'prelu':
            self.act = nn.PReLU(num_parameters=self.out_planes)
        elif self.act_type == 'relu':
            self.act = nn.ReLU(inplace=True)
        elif self.act_type == 'rrelu':
            self.act = nn.RReLU(lower=-0.05, upper=0.05)
        elif self.act_type == 'softplus':
            self.act = nn.Softplus()
        elif self.act_type == 'linear':
            # No activation module is created; forward() skips it.
            pass
        else:
            raise ValueError('The type of activation if not support!')

    def forward(self, x):
        """Run the multi-branch path (training) or the fused conv (eval)."""
        if self.training:
            y = (self.conv3x3(x)
                 + self.conv1x1_3x3(x)
                 + self.conv1x1_sbx(x)
                 + self.conv1x1_sby(x)
                 + self.conv1x1_lpl(x))
            if self.with_idt:
                y += x
        else:
            RK, RB = self.rep_params()
            y = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1)
        if self.act_type != 'linear':
            y = self.act(y)
        return y

    def rep_params(self):
        """Fold every branch into the weight and bias of one 3x3 conv.

        Returns:
            Tuple ``(RK, RB)`` of the fused kernel and bias, usable with
            ``F.conv2d(..., stride=1, padding=1)``.
        """
        K0, B0 = self.conv3x3.weight, self.conv3x3.bias
        K1, B1 = self.conv1x1_3x3.rep_params()
        K2, B2 = self.conv1x1_sbx.rep_params()
        K3, B3 = self.conv1x1_sby.rep_params()
        K4, B4 = self.conv1x1_lpl.rep_params()
        RK, RB = (K0 + K1 + K2 + K3 + K4), (B0 + B1 + B2 + B3 + B4)

        if self.with_idt:
            # Identity as a 3x3 kernel: a 1 at the centre tap of each
            # channel's own filter. Built in RK's dtype/device so the sum
            # does not promote or move the fused weight.
            K_idt = torch.zeros_like(RK)
            K_idt[:, :, 1, 1] = torch.eye(
                self.out_planes, dtype=RK.dtype, device=RK.device)
            RK = RK + K_idt
        return RK, RB
if __name__ == '__main__':
    # Self-check: each printed tensor is the difference between the
    # multi-branch output and the single fused conv; it should be ~0.
    # Falls back to CPU so the check also runs on hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # test seq-conv
    x = torch.randn(1, 3, 5, 5).to(device)
    conv = SeqConv3x3('conv1x1-conv3x3', 3, 3, 2).to(device)
    y0 = conv(x)
    RK, RB = conv.rep_params()
    y1 = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1)
    print(y0 - y1)

    # test ecb (module is in training mode, so y0 uses the branch path)
    x = torch.randn(1, 3, 5, 5).to(device) * 200
    ecb = ECB(3, 3, 2, act_type='linear', with_idt=True).to(device)
    y0 = ecb(x)
    RK, RB = ecb.rep_params()
    y1 = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1)
    print(y0 - y1)
|