# The implementation is adopted from mmsegmentation,
# made publicly available under the Apache License, Version 2.0 at https://github.com/open-mmlab/mmsegmentation

import warnings
from abc import ABCMeta, abstractmethod

import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16, force_fp32


# classification head
class LinearClassifier(nn.Module):
    """Single linear layer applied to the last element of the input.

    Args:
        in_channels (int): Size of each input feature vector.
        num_classes (int): Number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super(LinearClassifier, self).__init__()
        self.classifier = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        """Classify ``x[-1]``; all other elements of ``x`` are ignored."""
        return self.classifier(x[-1])


# segmentation head
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Thin wrapper around ``F.interpolate``.

    Args:
        input (Tensor): Tensor to resample.
        size (int | tuple[int] | None): Target spatial size.
        scale_factor (float | tuple[float] | None): Spatial multiplier.
        mode (str): Interpolation mode. Default: 'nearest'.
        align_corners (bool | None): Forwarded to ``F.interpolate``.
        warning (bool): Accepted for signature compatibility; unused here.

    Returns:
        Tensor: The resampled tensor.
    """
    return F.interpolate(input, size, scale_factor, mode, align_corners)


class Upsample(nn.Module):
    """Module form of :func:`resize` with a fixed size or scale factor.

    Args:
        size (int | tuple[int] | None): Fixed output size. When falsy, the
            output size is derived from ``scale_factor`` at forward time.
        scale_factor (float | tuple[float] | None): Multiplier applied to
            the last two dimensions of the input. A tuple gives one factor
            per spatial dimension.
        mode (str): Interpolation mode. Default: 'nearest'.
        align_corners (bool | None): Forwarded to ``F.interpolate``.
    """

    def __init__(self,
                 size=None,
                 scale_factor=None,
                 mode='nearest',
                 align_corners=None):
        super(Upsample, self).__init__()
        self.size = size
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        else:
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resize ``x`` to ``self.size`` or by ``self.scale_factor``."""
        if self.size:
            size = self.size
        elif isinstance(self.scale_factor, tuple):
            # One factor per spatial dimension; multiplying a dimension by
            # the tuple itself would raise a TypeError.
            size = [
                int(dim * factor)
                for dim, factor in zip(x.shape[-2:], self.scale_factor)
            ]
        else:
            size = [int(dim * self.scale_factor) for dim in x.shape[-2:]]
        # The explicit size is always passed, never the scale factor.
        return resize(x, size, None, self.mode, self.align_corners)


class FPN(BaseModule):
    """Feature Pyramid Network.

    This neck is the implementation of `Feature Pyramid Networks for Object
    Detection <https://arxiv.org/abs/1612.03144>`_.

    Args:
        in_channels (list[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        add_extra_convs (bool | str): If bool, it decides whether to add conv
            layers on top of the original feature maps. Default to False.
            If True, its actual mode is specified by `extra_convs_on_inputs`.
            If str, it specifies the source feature map of the extra convs.
            Only the following options are allowed

            - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
            - 'on_lateral': Last feature map after lateral convs.
            - 'on_output': The last output feature map after fpn convs.
        extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
            on the original feature from the backbone. If True,
            it is equivalent to `add_extra_convs='on_input'`. If False, it is
            equivalent to set `add_extra_convs='on_output'`. Default: False.
        relu_before_extra_convs (bool): Whether to apply relu before the extra
            conv. Default: False.
        no_norm_on_lateral (bool): Whether to apply norm on lateral.
            Default: False.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: None.
        upsample_cfg (dict): Config dict for interpolate layer.
            Default: dict(mode='nearest').
        init_cfg (dict or list[dict], optional): Initialization config dict.

    Example:
        >>> import torch
        >>> in_channels = [2, 3, 5, 7]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = FPN(in_channels, 11, len(in_channels)).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 11, 340, 340])
        outputs[1].shape = torch.Size([1, 11, 170, 170])
        outputs[2].shape = torch.Size([1, 11, 84, 84])
        outputs[3].shape = torch.Size([1, 11, 43, 43])
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 extra_convs_on_inputs=False,
                 relu_before_extra_convs=False,
                 no_norm_on_lateral=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 upsample_cfg=dict(mode='nearest'),
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super(FPN, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.relu_before_extra_convs = relu_before_extra_convs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.fp16_enabled = False
        # Copy so the shared default dict is never mutated.
        self.upsample_cfg = upsample_cfg.copy()

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs
        assert isinstance(add_extra_convs, (str, bool))
        if isinstance(add_extra_convs, str):
            # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
            assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
        elif add_extra_convs:  # True
            if extra_convs_on_inputs:
                # For compatibility with previous release
                # TODO: deprecate `extra_convs_on_inputs`
                self.add_extra_convs = 'on_input'
            else:
                self.add_extra_convs = 'on_output'

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()

        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                act_cfg=act_cfg,
                inplace=False)
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)

            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # add extra conv layers (e.g., RetinaNet)
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        if self.add_extra_convs and extra_levels >= 1:
            for i in range(extra_levels):
                # Only the first extra conv may read the raw backbone
                # feature; every later one consumes an FPN-width map.
                if i == 0 and self.add_extra_convs == 'on_input':
                    extra_in_channels = self.in_channels[
                        self.backbone_end_level - 1]
                else:
                    extra_in_channels = out_channels
                extra_fpn_conv = ConvModule(
                    extra_in_channels,
                    out_channels,
                    3,
                    stride=2,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    inplace=False)
                self.fpn_convs.append(extra_fpn_conv)

    @auto_fp16()
    def forward(self, inputs):
        """Build the pyramid from ``inputs``.

        Args:
            inputs (Sequence[Tensor]): One feature map per entry of
                ``in_channels``.

        Returns:
            tuple[Tensor]: The output feature maps (``num_outs`` of them,
            unless extra levels were requested without matching convs).
        """
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
            #  it cannot co-exist with `size` in `F.interpolate`.
            if 'scale_factor' in self.upsample_cfg:
                laterals[i - 1] = laterals[i - 1] + resize(
                    laterals[i], **self.upsample_cfg)
            else:
                prev_shape = laterals[i - 1].shape[2:]
                laterals[i - 1] = laterals[i - 1] + resize(
                    laterals[i], size=prev_shape, **self.upsample_cfg)

        # build outputs
        # part 1: from original levels
        outs = [
            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
        ]
        # part 2: add extra levels
        if self.num_outs > len(outs):
            # use max pool to get more levels on top of outputs
            # (e.g., Faster R-CNN, Mask R-CNN)
            if not self.add_extra_convs:
                for i in range(self.num_outs - used_backbone_levels):
                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
            # add conv layers on top of original feature maps (RetinaNet)
            else:
                if self.add_extra_convs == 'on_input':
                    extra_source = inputs[self.backbone_end_level - 1]
                elif self.add_extra_convs == 'on_lateral':
                    extra_source = laterals[-1]
                elif self.add_extra_convs == 'on_output':
                    extra_source = outs[-1]
                else:
                    raise NotImplementedError
                outs.append(self.fpn_convs[used_backbone_levels](extra_source))
                for i in range(used_backbone_levels + 1, self.num_outs):
                    if self.relu_before_extra_convs:
                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
                    else:
                        outs.append(self.fpn_convs[i](outs[-1]))
        return tuple(outs)


class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
    """Base class for BaseDecodeHead.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        out_channels (int): Output channels of conv_seg.
        threshold (float): Threshold for binary segmentation in the case of
            `out_channels==1`. Default: None.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            None: Only one select feature map is allowed.
            Default: None.
        loss_decode (dict | Sequence[dict]): Config of decode loss.
            The `loss_name` is property of corresponding loss function which
            could be shown in training log. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_ce'.
             e.g. dict(type='CrossEntropyLoss'),
             [dict(type='CrossEntropyLoss', loss_name='loss_ce'),
              dict(type='DiceLoss', loss_name='loss_dice')]
            Default: dict(type='CrossEntropyLoss').
            NOTE(review): this argument is accepted but not used anywhere in
            this class as visible here.
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255.
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 out_channels=None,
                 threshold=None,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 ignore_index=255,
                 sampler=None,
                 align_corners=False,
                 init_cfg=dict(
                     type='Normal', std=0.01, override=dict(name='conv_seg'))):
        super(BaseDecodeHead, self).__init__(init_cfg)
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index

        self.ignore_index = ignore_index
        self.align_corners = align_corners

        if out_channels is None:
            if num_classes == 2:
                warnings.warn('For binary segmentation, we suggest using'
                              '`out_channels = 1` to define the output'
                              'channels of segmentor, and use `threshold`'
                              'to convert seg_logist into a prediction'
                              'applying a threshold')
            out_channels = num_classes

        if out_channels != num_classes and out_channels != 1:
            raise ValueError(
                'out_channels should be equal to num_classes,'
                'except binary segmentation set out_channels == 1 and'
                f'num_classes == 2, but got out_channels={out_channels}'
                f'and num_classes={num_classes}')

        if out_channels == 1 and threshold is None:
            threshold = 0.3
            warnings.warn('threshold is not defined for binary, and defaults'
                          'to 0.3')
        self.num_classes = num_classes
        self.out_channels = out_channels
        self.threshold = threshold

        self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        self.fp16_enabled = False

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only single feature map
        will be selected. So in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        lists or tuples of the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated
                    together. Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one select feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                # Concatenation along channels: widths add up.
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor | list[Tensor]: A single tensor for 'resize_concat' or
            when no transform is set; a list of the selected tensors for
            'multiple_select'.
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    @auto_fp16()
    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self(inputs)
        # NOTE(review): `losses` is not defined on this class in this file;
        # a subclass must provide it before this method can be used.
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs)

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output


class FPNHead(BaseDecodeHead):
    """Panoptic Feature Pyramid Networks.

    This head is the implementation of `Semantic FPN
    <https://arxiv.org/abs/1901.02446>`_.

    Args:
        feature_strides (tuple[int]): The strides for input feature maps.
            stack_lateral. All strides suppose to be power of 2. The first
            one is of largest resolution.
    """

    def __init__(self, feature_strides, **kwargs):
        super(FPNHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.feature_strides = feature_strides

        self.scale_heads = nn.ModuleList()
        for i in range(len(feature_strides)):
            # One conv (+ 2x upsample) stage per factor of two between this
            # level's stride and the finest stride; at least one conv.
            head_length = max(
                1,
                int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
            scale_head = []
            for k in range(head_length):
                scale_head.append(
                    ConvModule(
                        self.in_channels[i] if k == 0 else self.channels,
                        self.channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                if feature_strides[i] != feature_strides[0]:
                    scale_head.append(
                        Upsample(
                            scale_factor=2,
                            mode='bilinear',
                            align_corners=self.align_corners))
            self.scale_heads.append(nn.Sequential(*scale_head))

    def forward(self, inputs):
        """Sum the per-level heads at the first level's size and classify.

        Args:
            inputs (list[Tensor]): Multi-level features; the levels listed
                in ``in_index`` are used.

        Returns:
            Tensor: Output of ``cls_seg`` on the fused feature map.
        """
        x = self._transform_inputs(inputs)

        output = self.scale_heads[0](x[0])
        for i in range(1, len(self.feature_strides)):
            # non inplace
            output = output + resize(
                self.scale_heads[i](x[i]),
                size=output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)

        output = self.cls_seg(output)
        return output


class FPNSegmentor(nn.Module):
    """Packed segmentor head: per-level rescaling, FPN neck and FPN head.

    The four selected feature maps are rescaled by four fixed branches
    (two stacked stride-2 transposed convs, one stride-2 transposed conv,
    identity, stride-2 max pool) before entering the neck.

    Args:
        fpn_layer_indices: tuple of the indices of layers
        neck_cfg: dict of FPN params
        head_cfg: dict of FPNHead params
    """

    def __init__(self,
                 fpn_layer_indices=(3, 5, 7, 11),
                 neck_cfg=dict(
                     in_channels=[768, 768, 768, 768],
                     out_channels=256,
                     num_outs=4),
                 head_cfg=dict(
                     in_channels=[256, 256, 256, 256],
                     in_index=[0, 1, 2, 3],
                     feature_strides=[4, 8, 16, 32],
                     channels=128,
                     dropout_ratio=0.1,
                     num_classes=21,
                     norm_cfg=dict(type='BN', requires_grad=True),
                     align_corners=False,
                     loss_decode=dict(
                         type='CrossEntropyLoss',
                         use_sigmoid=False,
                         loss_weight=1.0))):
        super(FPNSegmentor, self).__init__()
        self.fpn_layer_indices = fpn_layer_indices

        # All branches use the channel width of the first neck input.
        width = neck_cfg['in_channels'][0]
        self.pre_fpn = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(width, width, kernel_size=2, stride=2),
                nn.BatchNorm2d(width),
                nn.GELU(),
                nn.ConvTranspose2d(width, width, kernel_size=2, stride=2),
            ),
            nn.ConvTranspose2d(width, width, kernel_size=2, stride=2),
            nn.Identity(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ])

        self.fpn_neck = FPN(**neck_cfg)
        self.fpn_head = FPNHead(**head_cfg)

        # for vis
        self.NUM_CLASSES = head_cfg['num_classes']
        # Fixed-seed palette; the global RNG state is saved and restored so
        # seeding here does not affect randomness elsewhere.
        state = np.random.get_state()
        np.random.seed(42)
        palette = np.random.randint(0, 255, size=(self.NUM_CLASSES, 3))
        np.random.set_state(state)
        self.PALETTE = palette

    def show_result(self,
                    img,
                    result,
                    palette=None,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    opacity=0.5):
        """Draw `result` over `img`.

        Args:
            img (str or Tensor): The image to be displayed.
            result (Tensor): The semantic segmentation logits to draw over
                `img`, indexed as (batch, class, height, width). Only the
                first batch element is drawn.
            palette (list[list[int]]] | np.ndarray | None): The palette of
                segmentation map. If None is given, `self.PALETTE` is used
                (or a fixed-seed random palette when that is None).
                Default: None
            win_name (str): The window name.
            wait_time (int): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image.
                Default: False.
            out_file (str or None): The filename to write the image.
                Default: None.
            opacity(float): Opacity of painted segmentation map.
                Default 0.5.
                Must be in (0, 1] range.

        Returns:
            np.ndarray | None: The blended image, only when neither `show`
            nor `out_file` is set; otherwise None.
        """
        img = mmcv.imread(img)
        img = img.copy()
        seg = result
        if palette is None:
            if self.PALETTE is None:
                # Get random state before set seed,
                # and restore random state later.
                # It will prevent loss of randomness, as the palette
                # may be different in each iteration if not specified.
                # See: https://github.com/open-mmlab/mmdetection/issues/5844
                state = np.random.get_state()
                np.random.seed(42)
                # random palette
                palette = np.random.randint(
                    0, 255, size=(self.NUM_CLASSES, 3))
                np.random.set_state(state)
            else:
                palette = self.PALETTE
        palette = np.array(palette)
        assert palette.shape[0] == self.NUM_CLASSES
        assert palette.shape[1] == 3
        assert len(palette.shape) == 2
        assert 0 < opacity <= 1.0

        assert seg.shape[1] == self.NUM_CLASSES
        if seg.shape[2] != img.shape[0] or seg.shape[3] != img.shape[1]:
            seg = resize(seg, (img.shape[0], img.shape[1]), None, 'bilinear',
                         True)
        seg = seg[0]
        # Move the label map to a CPU numpy array so it can index the numpy
        # colour buffer below regardless of the logits' device.
        seg = torch.argmax(seg, dim=0).cpu().numpy()

        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # convert to BGR
        color_seg = color_seg[..., ::-1]

        img = img * (1 - opacity) + color_seg * opacity
        img = img.astype(np.uint8)
        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False

        if show:
            mmcv.imshow(img, win_name, wait_time)
        if out_file is not None:
            mmcv.imwrite(img, out_file)

        if not (show or out_file):
            warnings.warn('show==False and out_file is not specified, only '
                          'result image will be returned')
            return img
        return None

    def forward(self, x):
        """Run the selected layers of ``x`` through pre-FPN, neck and head.

        Args:
            x (Sequence[Tensor]): Per-layer feature maps, indexed by
                ``fpn_layer_indices``.

        Returns:
            Tensor: Segmentation logits from the FPN head.
        """
        selected = [x[idx] for idx in self.fpn_layer_indices]
        rescaled = [self.pre_fpn[i](feat) for i, feat in enumerate(selected)]
        pyramid = self.fpn_neck(rescaled)
        return self.fpn_head(pyramid)