# Copyright 2022 OFA-Sys Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a
    separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    I've opted for changing the layer and argument names to 'drop path' rather than mix
    DropConnect as a layer name and use 'survival rate' as the argument.

    Args:
        x (`torch.Tensor`): Input whose first dimension indexes the samples.
        drop_prob (`float`, **optional**, default to `0.`):
            Probability of zeroing a whole sample. `None` is treated as `0.`.
        training (`bool`, **optional**, default to `False`):
            Paths are only dropped when this is `True`.

    Returns:
        `x` itself when nothing is dropped, otherwise a tensor of the same shape in which
        each sample is either zeroed or rescaled by `1 / (1 - drop_prob)`.
    """
    # `not drop_prob` covers both 0.0 and None (the default of `DropPath`), so a
    # default-constructed DropPath no longer fails on `1 - None` in training mode.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.:
        # Every sample is dropped; rescaling by 1 / keep_prob would produce NaN/inf.
        return torch.zeros_like(x)
    # One Bernoulli draw per sample, broadcast over all remaining dims
    # (works with tensors of any rank, not just 2D ConvNets).
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    return x.div(keep_prob) * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob (`float`, **optional**, default to `None`):
            Probability of dropping a sample's path; `None` disables dropping.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # `self.training` is toggled by `train()` / `eval()`.
        return drop_path(x, self.drop_prob, self.training)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding (padding equals `dilation`, no bias)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (no bias)."""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """Two-layer (3x3, 3x3) residual block.

    Args:
        inplanes (`int`): Number of input channels.
        planes (`int`): Number of output channels.
        stride (`int`, **optional**, default to `1`): Stride of the first 3x3 conv.
        downsample (`torch.nn.Module`, **optional**, default to `None`):
            Module applied to the input to produce the shortcut when given.
        groups (`int`, **optional**, default to `1`): Only `1` is supported.
        base_width (`int`, **optional**, default to `64`): Only `64` is supported.
        dilation (`int`, **optional**, default to `1`): Values above `1` are not supported.
        norm_layer (`torch.nn.Module`, **optional**, default to `None`):
            Normalization layer factory. If `None`, `torch.nn.BatchNorm2d` is used.
        drop_path_rate (`float`, **optional**, default to `0.0`):
            Drop path rate applied to the residual branch.

    Raises:
        ValueError: If `groups != 1` or `base_width != 64`.
        NotImplementedError: If `dilation > 1`.
    """
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None,
                 drop_path_rate=0.0):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                'Dilation > 1 not supported in BasicBlock')
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        # Same keyword as Bottleneck, so `ResNet._make_layer` can build either block type.
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

    def forward(self, x):
        # NOTE: a leading `assert False` used to make this method unusable
        # (and silently usable under `python -O`); it has been removed.
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = identity + self.drop_path(out)
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Three-layer (1x1, 3x3, 1x1) residual block with optional drop path.

    Args:
        inplanes (`int`): Number of input channels.
        planes (`int`): Base width; the block outputs `planes * expansion` channels.
        stride (`int`, **optional**, default to `1`): Stride of the 3x3 conv.
        downsample (`torch.nn.Module`, **optional**, default to `None`):
            Module applied to the input to produce the shortcut when given.
        groups (`int`, **optional**, default to `1`): Groups of the 3x3 conv.
        base_width (`int`, **optional**, default to `64`): Width per group.
        dilation (`int`, **optional**, default to `1`): Dilation of the 3x3 conv.
        norm_layer (`torch.nn.Module`, **optional**, default to `None`):
            Normalization layer factory. If `None`, `torch.nn.BatchNorm2d` is used.
        drop_path_rate (`float`, **optional**, default to `0.0`):
            Drop path rate applied to the residual branch.
    """
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None,
                 drop_path_rate=0.0):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Only the residual branch is subject to drop path; the shortcut is always kept.
        out = identity + self.drop_path(out)
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    r"""
    Deep residual network, copy from https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py.

    You can see more details from https://arxiv.org/abs/1512.03385

    step 1. Get image embedding with a `7x7` convolution with stride `2`.
    step 2. Apply normalization (`norm_layer`, `BatchNorm2d` by default), relu activation and max pooling.
    step 3. Go through three residual branches.
    """

    def __init__(self,
                 layers,
                 zero_init_residual=False,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=None,
                 drop_path_rate=0.0):
        r"""
        Args:
            layers (`Tuple[int]`): There are three residual branches in this resnet, so `layers`
                must have at least three elements. Each of the first three elements is the
                number of `Bottleneck` blocks in the corresponding residual branch.
            zero_init_residual (`bool`, **optional**, default to `False`):
                Whether or not to zero-initialize the last BN in each residual branch.
            groups (`int`, **optional**, default to `1`):
                The number of groups, forwarded to every block.
            width_per_group (`int`, **optional**, default to `64`):
                The width in each group, forwarded to every block as `base_width`.
            replace_stride_with_dilation (`Tuple[bool]`, **optional**, default to `None`):
                Whether or not to replace stride with dilation in each residual branch.
                Must have exactly three elements when given.
            norm_layer (`torch.nn.Module`, **optional**, default to `None`):
                The normalization module. If `None`, will use `torch.nn.BatchNorm2d`.
            drop_path_rate (`float`, **optional**, default to 0.0):
                Drop path rate. See more details about drop path from
                https://arxiv.org/pdf/1605.07648v4.pdf.

        Raises:
            ValueError: If `replace_stride_with_dilation` does not have three elements.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError('replace_stride_with_dilation should be None '
                             'or a 3-element tuple, got {}'.format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(
            Bottleneck, 64, layers[0], drop_path_rate=drop_path_rate)
        self.layer2 = self._make_layer(
            Bottleneck,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0],
            drop_path_rate=drop_path_rate)
        self.layer3 = self._make_layer(
            Bottleneck,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1],
            drop_path_rate=drop_path_rate)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m,
                            (nn.SyncBatchNorm, nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self,
                    block,
                    planes,
                    blocks,
                    stride=1,
                    dilate=False,
                    drop_path_rate=0.0):
        r"""
        Making a single residual branch.

        step 1. If dilate==`True`, fold the stride into the dilation and use a stride of 1.
        step 2. If the stride is not 1 or the input dimension doesn't equal the output
            dimension of `block`, initialize a down sample module.
        step 3. Build a sequential of `blocks` number of `block`.

        Args:
            block (`torch.nn.Module`): The basic block in residual branch.
            planes (`int`): The base dimension of each block (output is `planes * block.expansion`).
            blocks (`int`): The number of `block` in residual branch.
            stride (`int`, **optional**, default to `1`): The stride using in conv.
            dilate (`bool`, **optional**, default to `False`): Whether or not to replace stride with dilation.
            drop_path_rate (`float`, **optional**, default to 0.0):
                Maximum drop path rate of the branch; the rate grows linearly from 0 at the
                first block to this value at the last block. See more details about drop
                path from https://arxiv.org/pdf/1605.07648v4.pdf.

        Returns:
            A sequential of basic layer with type `torch.nn.Sequential[block]`
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        # The first block carries the stride/downsample and gets no drop path
        # (its scheduled rate, dpr[0], is 0 anyway).
        layers = [
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer)
        ]
        self.inplanes = planes * block.expansion

        # Linearly increasing per-block drop path rate.
        dpr = torch.linspace(0, drop_path_rate, blocks).tolist()
        for i in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                    drop_path_rate=dpr[i]))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # Stem: conv -> norm -> relu -> max pool.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Three residual branches.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)