head.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688
  1. # The implementation is adopted from mmsegmentation,
  2. # made publicly available under the Apache License, Version 2.0 at https://github.com/open-mmlab/mmsegmentation
import warnings
from abc import ABCMeta, abstractmethod

import mmcv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16, force_fp32
  11. # classification head
  12. class LinearClassifier(nn.Module):
  13. def __init__(self, in_channels, num_classes):
  14. super(LinearClassifier, self).__init__()
  15. self.classifier = nn.Linear(in_channels, num_classes)
  16. def forward(self, x):
  17. return self.classifier(x[-1])
  18. # segmentation head
  19. def resize(input,
  20. size=None,
  21. scale_factor=None,
  22. mode='nearest',
  23. align_corners=None,
  24. warning=True):
  25. return F.interpolate(input, size, scale_factor, mode, align_corners)
  26. class Upsample(nn.Module):
  27. def __init__(self,
  28. size=None,
  29. scale_factor=None,
  30. mode='nearest',
  31. align_corners=None):
  32. super(Upsample, self).__init__()
  33. self.size = size
  34. if isinstance(scale_factor, tuple):
  35. self.scale_factor = tuple(float(factor) for factor in scale_factor)
  36. else:
  37. self.scale_factor = float(scale_factor) if scale_factor else None
  38. self.mode = mode
  39. self.align_corners = align_corners
  40. def forward(self, x):
  41. if not self.size:
  42. size = [int(t * self.scale_factor) for t in x.shape[-2:]]
  43. else:
  44. size = self.size
  45. return resize(x, size, None, self.mode, self.align_corners)
class FPN(BaseModule):
    """Feature Pyramid Network.

    This neck is the implementation of `Feature Pyramid Networks for Object
    Detection <https://arxiv.org/abs/1612.03144>`_.

    Each selected input level goes through a 1x1 lateral conv; coarser
    laterals are resized and added into finer ones (top-down path); a 3x3
    conv then produces the output for each level. When more outputs than
    input levels are requested, extra levels are appended either by strided
    sub-sampling or by additional stride-2 convs.

    Args:
        in_channels (list[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        add_extra_convs (bool | str): If bool, it decides whether to add conv
            layers on top of the original feature maps. Default to False.
            If True, its actual mode is specified by `extra_convs_on_inputs`.
            If str, it specifies the source feature map of the extra convs.
            Only the following options are allowed

            - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
            - 'on_lateral': Last feature map after lateral convs.
            - 'on_output': The last output feature map after fpn convs.
        extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
            on the original feature from the backbone. If True,
            it is equivalent to `add_extra_convs='on_input'`. If False, it is
            equivalent to set `add_extra_convs='on_output'`. Only consulted
            when `add_extra_convs` is True. Default to False.
        relu_before_extra_convs (bool): Whether to apply relu before the extra
            conv. Default: False.
        no_norm_on_lateral (bool): Whether to skip the norm layer on lateral
            convs. Default: False.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: None.
        upsample_cfg (dict): Config dict for interpolate layer.
            Default: dict(mode='nearest').
        init_cfg (dict or list[dict], optional): Initialization config dict.

    Example:
        >>> import torch
        >>> in_channels = [2, 3, 5, 7]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = FPN(in_channels, 11, len(in_channels)).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 11, 340, 340])
        outputs[1].shape = torch.Size([1, 11, 170, 170])
        outputs[2].shape = torch.Size([1, 11, 84, 84])
        outputs[3].shape = torch.Size([1, 11, 43, 43])
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 extra_convs_on_inputs=False,
                 relu_before_extra_convs=False,
                 no_norm_on_lateral=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 upsample_cfg=dict(mode='nearest'),
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super(FPN, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.relu_before_extra_convs = relu_before_extra_convs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.fp16_enabled = False
        # Copy so the shared default dict is never mutated through self.
        self.upsample_cfg = upsample_cfg.copy()

        if end_level == -1:
            # Use every input level from start_level on; extra output levels
            # beyond the inputs are allowed in this case.
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs
        assert isinstance(add_extra_convs, (str, bool))
        if isinstance(add_extra_convs, str):
            # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
            assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
        elif add_extra_convs:  # True
            # Resolve the boolean flag into one of the string modes.
            if extra_convs_on_inputs:
                # For compatibility with previous release
                # TODO: deprecate `extra_convs_on_inputs`
                self.add_extra_convs = 'on_input'
            else:
                self.add_extra_convs = 'on_output'

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()

        # One 1x1 lateral conv and one 3x3 output conv per used input level.
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                act_cfg=act_cfg,
                inplace=False)
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)

            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # add extra conv layers (e.g., RetinaNet)
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        if self.add_extra_convs and extra_levels >= 1:
            for i in range(extra_levels):
                # Only the first extra conv may read the raw backbone
                # feature; every later one consumes a previous output.
                # NOTE: this rebinds the local `in_channels` argument;
                # `self.in_channels` is unaffected.
                if i == 0 and self.add_extra_convs == 'on_input':
                    in_channels = self.in_channels[self.backbone_end_level - 1]
                else:
                    in_channels = out_channels
                extra_fpn_conv = ConvModule(
                    in_channels,
                    out_channels,
                    3,
                    stride=2,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    inplace=False)
                # Extra convs are stored after the per-level output convs,
                # so forward() addresses them by index >= number of laterals.
                self.fpn_convs.append(extra_fpn_conv)

    @auto_fp16()
    def forward(self, inputs):
        """Build the feature pyramid.

        Args:
            inputs (Sequence[Tensor]): One feature map per entry of
                ``in_channels``.

        Returns:
            tuple[Tensor]: ``num_outs`` feature maps.
        """
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
            # it cannot co-exist with `size` in `F.interpolate`.
            if 'scale_factor' in self.upsample_cfg:
                laterals[i - 1] = laterals[i - 1] + resize(
                    laterals[i], **self.upsample_cfg)
            else:
                # Resize to the exact spatial shape of the finer level.
                prev_shape = laterals[i - 1].shape[2:]
                laterals[i - 1] = laterals[i - 1] + resize(
                    laterals[i], size=prev_shape, **self.upsample_cfg)

        # build outputs
        # part 1: from original levels
        outs = [
            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
        ]
        # part 2: add extra levels
        if self.num_outs > len(outs):
            # use max pool to get more levels on top of outputs
            # (e.g., Faster R-CNN, Mask R-CNN)
            if not self.add_extra_convs:
                for i in range(self.num_outs - used_backbone_levels):
                    # Kernel size 1 with stride 2: plain sub-sampling.
                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
            # add conv layers on top of original feature maps (RetinaNet)
            else:
                if self.add_extra_convs == 'on_input':
                    extra_source = inputs[self.backbone_end_level - 1]
                elif self.add_extra_convs == 'on_lateral':
                    extra_source = laterals[-1]
                elif self.add_extra_convs == 'on_output':
                    extra_source = outs[-1]
                else:
                    raise NotImplementedError
                # The first extra conv is applied without the optional relu.
                outs.append(self.fpn_convs[used_backbone_levels](extra_source))
                for i in range(used_backbone_levels + 1, self.num_outs):
                    if self.relu_before_extra_convs:
                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
                    else:
                        outs.append(self.fpn_convs[i](outs[-1]))
        return tuple(outs)
  234. class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
  235. """Base class for BaseDecodeHead.
  236. Args:
  237. in_channels (int|Sequence[int]): Input channels.
  238. channels (int): Channels after modules, before conv_seg.
  239. num_classes (int): Number of classes.
  240. out_channels (int): Output channels of conv_seg.
  241. threshold (float): Threshold for binary segmentation in the case of
  242. `out_channels==1`. Default: None.
  243. dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
  244. conv_cfg (dict|None): Config of conv layers. Default: None.
  245. norm_cfg (dict|None): Config of norm layers. Default: None.
  246. act_cfg (dict): Config of activation layers.
  247. Default: dict(type='ReLU')
  248. in_index (int|Sequence[int]): Input feature index. Default: -1
  249. input_transform (str|None): Transformation type of input features.
  250. Options: 'resize_concat', 'multiple_select', None.
  251. 'resize_concat': Multiple feature maps will be resize to the
  252. same size as first one and than concat together.
  253. Usually used in FCN head of HRNet.
  254. 'multiple_select': Multiple feature maps will be bundle into
  255. a list and passed into decode head.
  256. None: Only one select feature map is allowed.
  257. Default: None.
  258. loss_decode (dict | Sequence[dict]): Config of decode loss.
  259. The `loss_name` is property of corresponding loss function which
  260. could be shown in training log. If you want this loss
  261. item to be included into the backward graph, `loss_` must be the
  262. prefix of the name. Defaults to 'loss_ce'.
  263. e.g. dict(type='CrossEntropyLoss'),
  264. [dict(type='CrossEntropyLoss', loss_name='loss_ce'),
  265. dict(type='DiceLoss', loss_name='loss_dice')]
  266. Default: dict(type='CrossEntropyLoss').
  267. ignore_index (int | None): The label index to be ignored. When using
  268. masked BCE loss, ignore_index should be set to None. Default: 255.
  269. sampler (dict|None): The config of segmentation map sampler.
  270. Default: None.
  271. align_corners (bool): align_corners argument of F.interpolate.
  272. Default: False.
  273. init_cfg (dict or list[dict], optional): Initialization config dict.
  274. """
  275. def __init__(self,
  276. in_channels,
  277. channels,
  278. *,
  279. num_classes,
  280. out_channels=None,
  281. threshold=None,
  282. dropout_ratio=0.1,
  283. conv_cfg=None,
  284. norm_cfg=None,
  285. act_cfg=dict(type='ReLU'),
  286. in_index=-1,
  287. input_transform=None,
  288. loss_decode=dict(
  289. type='CrossEntropyLoss',
  290. use_sigmoid=False,
  291. loss_weight=1.0),
  292. ignore_index=255,
  293. sampler=None,
  294. align_corners=False,
  295. init_cfg=dict(
  296. type='Normal', std=0.01, override=dict(name='conv_seg'))):
  297. super(BaseDecodeHead, self).__init__(init_cfg)
  298. self._init_inputs(in_channels, in_index, input_transform)
  299. self.channels = channels
  300. self.dropout_ratio = dropout_ratio
  301. self.conv_cfg = conv_cfg
  302. self.norm_cfg = norm_cfg
  303. self.act_cfg = act_cfg
  304. self.in_index = in_index
  305. self.ignore_index = ignore_index
  306. self.align_corners = align_corners
  307. if out_channels is None:
  308. if num_classes == 2:
  309. warnings.warn('For binary segmentation, we suggest using'
  310. '`out_channels = 1` to define the output'
  311. 'channels of segmentor, and use `threshold`'
  312. 'to convert seg_logist into a prediction'
  313. 'applying a threshold')
  314. out_channels = num_classes
  315. if out_channels != num_classes and out_channels != 1:
  316. raise ValueError(
  317. 'out_channels should be equal to num_classes,'
  318. 'except binary segmentation set out_channels == 1 and'
  319. f'num_classes == 2, but got out_channels={out_channels}'
  320. f'and num_classes={num_classes}')
  321. if out_channels == 1 and threshold is None:
  322. threshold = 0.3
  323. warnings.warn('threshold is not defined for binary, and defaults'
  324. 'to 0.3')
  325. self.num_classes = num_classes
  326. self.out_channels = out_channels
  327. self.threshold = threshold
  328. self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
  329. if dropout_ratio > 0:
  330. self.dropout = nn.Dropout2d(dropout_ratio)
  331. else:
  332. self.dropout = None
  333. self.fp16_enabled = False
  334. def extra_repr(self):
  335. """Extra repr."""
  336. s = f'input_transform={self.input_transform}, ' \
  337. f'ignore_index={self.ignore_index}, ' \
  338. f'align_corners={self.align_corners}'
  339. return s
  340. def _init_inputs(self, in_channels, in_index, input_transform):
  341. """Check and initialize input transforms.
  342. The in_channels, in_index and input_transform must match.
  343. Specifically, when input_transform is None, only single feature map
  344. will be selected. So in_channels and in_index must be of type int.
  345. When input_transform
  346. Args:
  347. in_channels (int|Sequence[int]): Input channels.
  348. in_index (int|Sequence[int]): Input feature index.
  349. input_transform (str|None): Transformation type of input features.
  350. Options: 'resize_concat', 'multiple_select', None.
  351. 'resize_concat': Multiple feature maps will be resize to the
  352. same size as first one and than concat together.
  353. Usually used in FCN head of HRNet.
  354. 'multiple_select': Multiple feature maps will be bundle into
  355. a list and passed into decode head.
  356. None: Only one select feature map is allowed.
  357. """
  358. if input_transform is not None:
  359. assert input_transform in ['resize_concat', 'multiple_select']
  360. self.input_transform = input_transform
  361. self.in_index = in_index
  362. if input_transform is not None:
  363. assert isinstance(in_channels, (list, tuple))
  364. assert isinstance(in_index, (list, tuple))
  365. assert len(in_channels) == len(in_index)
  366. if input_transform == 'resize_concat':
  367. self.in_channels = sum(in_channels)
  368. else:
  369. self.in_channels = in_channels
  370. else:
  371. assert isinstance(in_channels, int)
  372. assert isinstance(in_index, int)
  373. self.in_channels = in_channels
  374. def _transform_inputs(self, inputs):
  375. """Transform inputs for decoder.
  376. Args:
  377. inputs (list[Tensor]): List of multi-level img features.
  378. Returns:
  379. Tensor: The transformed inputs
  380. """
  381. if self.input_transform == 'resize_concat':
  382. inputs = [inputs[i] for i in self.in_index]
  383. upsampled_inputs = [
  384. resize(
  385. input=x,
  386. size=inputs[0].shape[2:],
  387. mode='bilinear',
  388. align_corners=self.align_corners) for x in inputs
  389. ]
  390. inputs = torch.cat(upsampled_inputs, dim=1)
  391. elif self.input_transform == 'multiple_select':
  392. inputs = [inputs[i] for i in self.in_index]
  393. else:
  394. inputs = inputs[self.in_index]
  395. return inputs
  396. @auto_fp16()
  397. @abstractmethod
  398. def forward(self, inputs):
  399. """Placeholder of forward function."""
  400. pass
  401. def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
  402. """Forward function for training.
  403. Args:
  404. inputs (list[Tensor]): List of multi-level img features.
  405. img_metas (list[dict]): List of image info dict where each dict
  406. has: 'img_shape', 'scale_factor', 'flip', and may also contain
  407. 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
  408. For details on the values of these keys see
  409. `mmseg/datasets/pipelines/formatting.py:Collect`.
  410. gt_semantic_seg (Tensor): Semantic segmentation masks
  411. used if the architecture supports semantic segmentation task.
  412. train_cfg (dict): The training config.
  413. Returns:
  414. dict[str, Tensor]: a dictionary of loss components
  415. """
  416. seg_logits = self(inputs)
  417. losses = self.losses(seg_logits, gt_semantic_seg)
  418. return losses
  419. def forward_test(self, inputs, img_metas, test_cfg):
  420. """Forward function for testing.
  421. Args:
  422. inputs (list[Tensor]): List of multi-level img features.
  423. img_metas (list[dict]): List of image info dict where each dict
  424. has: 'img_shape', 'scale_factor', 'flip', and may also contain
  425. 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
  426. For details on the values of these keys see
  427. `mmseg/datasets/pipelines/formatting.py:Collect`.
  428. test_cfg (dict): The testing config.
  429. Returns:
  430. Tensor: Output segmentation map.
  431. """
  432. return self.forward(inputs)
  433. def cls_seg(self, feat):
  434. """Classify each pixel."""
  435. if self.dropout is not None:
  436. feat = self.dropout(feat)
  437. output = self.conv_seg(feat)
  438. return output
  439. class FPNHead(BaseDecodeHead):
  440. """Panoptic Feature Pyramid Networks.
  441. This head is the implementation of `Semantic FPN
  442. <https://arxiv.org/abs/1901.02446>`_.
  443. Args:
  444. feature_strides (tuple[int]): The strides for input feature maps.
  445. stack_lateral. All strides suppose to be power of 2. The first
  446. one is of largest resolution.
  447. """
  448. def __init__(self, feature_strides, **kwargs):
  449. super(FPNHead, self).__init__(
  450. input_transform='multiple_select', **kwargs)
  451. assert len(feature_strides) == len(self.in_channels)
  452. assert min(feature_strides) == feature_strides[0]
  453. self.feature_strides = feature_strides
  454. self.scale_heads = nn.ModuleList()
  455. for i in range(len(feature_strides)):
  456. head_length = max(
  457. 1,
  458. int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
  459. scale_head = []
  460. for k in range(head_length):
  461. scale_head.append(
  462. ConvModule(
  463. self.in_channels[i] if k == 0 else self.channels,
  464. self.channels,
  465. 3,
  466. padding=1,
  467. conv_cfg=self.conv_cfg,
  468. norm_cfg=self.norm_cfg,
  469. act_cfg=self.act_cfg))
  470. if feature_strides[i] != feature_strides[0]:
  471. scale_head.append(
  472. Upsample(
  473. scale_factor=2,
  474. mode='bilinear',
  475. align_corners=self.align_corners))
  476. self.scale_heads.append(nn.Sequential(*scale_head))
  477. def forward(self, inputs):
  478. x = self._transform_inputs(inputs)
  479. output = self.scale_heads[0](x[0])
  480. for i in range(1, len(self.feature_strides)):
  481. # non inplace
  482. output = output + resize(
  483. self.scale_heads[i](x[i]),
  484. size=output.shape[2:],
  485. mode='bilinear',
  486. align_corners=self.align_corners)
  487. output = self.cls_seg(output)
  488. return output
  489. class FPNSegmentor(nn.Module):
  490. '''
  491. Packed Sementor Head
  492. Args:
  493. fpn_layer_indices: tuple of the indices of layers
  494. neck_cfg: dict of FPN params
  495. head_cfg: dict of FPNHead params
  496. '''
  497. def __init__(self,
  498. fpn_layer_indices=(3, 5, 7, 11),
  499. neck_cfg=dict(
  500. in_channels=[768, 768, 768, 768],
  501. out_channels=256,
  502. num_outs=4),
  503. head_cfg=dict(
  504. in_channels=[256, 256, 256, 256],
  505. in_index=[0, 1, 2, 3],
  506. feature_strides=[4, 8, 16, 32],
  507. channels=128,
  508. dropout_ratio=0.1,
  509. num_classes=21,
  510. norm_cfg=dict(type='BN', requires_grad=True),
  511. align_corners=False,
  512. loss_decode=dict(
  513. type='CrossEntropyLoss',
  514. use_sigmoid=False,
  515. loss_weight=1.0))):
  516. super(FPNSegmentor, self).__init__()
  517. self.fpn_layer_indices = fpn_layer_indices
  518. width = neck_cfg['in_channels'][0]
  519. self.pre_fpn = nn.ModuleList([
  520. nn.Sequential(
  521. nn.ConvTranspose2d(width, width, kernel_size=2, stride=2),
  522. nn.BatchNorm2d(width),
  523. nn.GELU(),
  524. nn.ConvTranspose2d(width, width, kernel_size=2, stride=2),
  525. ),
  526. nn.ConvTranspose2d(width, width, kernel_size=2, stride=2),
  527. nn.Identity(),
  528. nn.MaxPool2d(kernel_size=2, stride=2)
  529. ])
  530. self.fpn_neck = FPN(**neck_cfg)
  531. self.fpn_head = FPNHead(**head_cfg)
  532. # for vis
  533. self.NUM_CLASSES = head_cfg['num_classes']
  534. state = np.random.get_state()
  535. np.random.seed(42)
  536. palette = np.random.randint(0, 255, size=(self.NUM_CLASSES, 3))
  537. np.random.set_state(state)
  538. self.PALETTE = palette
  539. def show_result(self,
  540. img,
  541. result,
  542. palette=None,
  543. win_name='',
  544. show=False,
  545. wait_time=0,
  546. out_file=None,
  547. opacity=0.5):
  548. """Draw `result` over `img`.
  549. Args:
  550. img (str or Tensor): The image to be displayed.
  551. result (Tensor): The semantic segmentation results to draw over
  552. `img`.
  553. palette (list[list[int]]] | np.ndarray | None): The palette of
  554. segmentation map. If None is given, random palette will be
  555. generated. Default: None
  556. win_name (str): The window name.
  557. wait_time (int): Value of waitKey param.
  558. Default: 0.
  559. show (bool): Whether to show the image.
  560. Default: False.
  561. out_file (str or None): The filename to write the image.
  562. Default: None.
  563. opacity(float): Opacity of painted segmentation map.
  564. Default 0.5.
  565. Must be in (0, 1] range.
  566. Returns:
  567. img (Tensor): Only if not `show` or `out_file`
  568. """
  569. img = mmcv.imread(img)
  570. img = img.copy()
  571. # seg = result[0]
  572. seg = result
  573. if palette is None:
  574. if self.PALETTE is None:
  575. # Get random state before set seed,
  576. # and restore random state later.
  577. # It will prevent loss of randomness, as the palette
  578. # may be different in each iteration if not specified.
  579. # See: https://github.com/open-mmlab/mmdetection/issues/5844
  580. state = np.random.get_state()
  581. np.random.seed(42)
  582. # random palette
  583. palette = np.random.randint(0, 255, size=(self.NUM_CLASSES, 3))
  584. np.random.set_state(state)
  585. else:
  586. palette = self.PALETTE
  587. palette = np.array(palette)
  588. assert palette.shape[0] == self.NUM_CLASSES
  589. assert palette.shape[1] == 3
  590. assert len(palette.shape) == 2
  591. assert 0 < opacity <= 1.0
  592. assert seg.shape[1] == self.NUM_CLASSES
  593. if seg.shape[2] != img.shape[0] or seg.shape[3] != img.shape[1]:
  594. seg = resize(seg, (img.shape[0], img.shape[1]), None, 'bilinear',
  595. True)
  596. seg = seg[0]
  597. seg = torch.argmax(seg, dim=0)
  598. color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
  599. for label, color in enumerate(palette):
  600. color_seg[seg == label, :] = color
  601. # convert to BGR
  602. color_seg = color_seg[..., ::-1]
  603. img = img * (1 - opacity) + color_seg * opacity
  604. img = img.astype(np.uint8)
  605. # if out_file specified, do not show image in window
  606. if out_file is not None:
  607. show = False
  608. if show:
  609. mmcv.imshow(img, win_name, wait_time)
  610. if out_file is not None:
  611. mmcv.imwrite(img, out_file)
  612. if not (show or out_file):
  613. warnings.warn('show==False and out_file is not specified, only '
  614. 'result image will be returned')
  615. return
  616. def forward(self, x):
  617. x = [x[idx] for idx in self.fpn_layer_indices]
  618. x = [self.pre_fpn[i](x[i]) for i in range(len(x))]
  619. x = self.fpn_neck(x)
  620. x = self.fpn_head(x)
  621. return x