models.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. # The implementation is also open-sourced by the authors,
  2. # and available at https://github.com/alibaba-mmai-research/TAdaConv
  3. # Copyright 2021-2022 The Alibaba FVI Team Authors. All rights reserved.
  4. import torch.nn as nn
  5. from .s3dg import Inception3D
  6. from .tada_convnext import TadaConvNeXt
  7. class BaseVideoModel(nn.Module):
  8. """
  9. Standard video model.
  10. The model is divided into the backbone and the head, where the backbone
  11. extracts features and the head performs classification.
  12. The backbones can be defined in model/base/backbone.py or anywhere else
  13. as long as the backbone is registered by the BACKBONE_REGISTRY.
  14. The heads can be defined in model/module_zoo/heads/ or anywhere else
  15. as long as the head is registered by the HEAD_REGISTRY.
  16. The registries automatically finds the registered modules and construct
  17. the base video model.
  18. """
  19. def __init__(self, cfg):
  20. """
  21. Args:
  22. cfg (Config): global config object.
  23. """
  24. super(BaseVideoModel, self).__init__()
  25. # the backbone is created according to meta-architectures
  26. # defined in models/base/backbone.py
  27. if cfg.MODEL.NAME == 'ConvNeXt_tiny':
  28. self.backbone = TadaConvNeXt(cfg)
  29. elif cfg.MODEL.NAME == 'S3DG':
  30. self.backbone = Inception3D(cfg)
  31. else:
  32. error_str = 'backbone {} is not supported, ConvNeXt_tiny or S3DG is supported'.format(
  33. cfg.MODEL.NAME)
  34. raise NotImplementedError(error_str)
  35. # the head is created according to the heads
  36. # defined in models/module_zoo/heads
  37. if cfg.VIDEO.HEAD.NAME == 'BaseHead':
  38. self.head = BaseHead(cfg)
  39. elif cfg.VIDEO.HEAD.NAME == 'AvgHead':
  40. self.head = AvgHead(cfg)
  41. else:
  42. error_str = 'head {} is not supported, BaseHead or AvgHead is supported'.format(
  43. cfg.VIDEO.HEAD.NAME)
  44. raise NotImplementedError(error_str)
  45. def forward(self, x):
  46. x = self.backbone(x)
  47. x = self.head(x)
  48. return x
  49. class BaseHead(nn.Module):
  50. """
  51. Constructs base head.
  52. """
  53. def __init__(
  54. self,
  55. cfg,
  56. ):
  57. """
  58. Args:
  59. cfg (Config): global config object.
  60. """
  61. super(BaseHead, self).__init__()
  62. self.cfg = cfg
  63. dim = cfg.VIDEO.BACKBONE.NUM_OUT_FEATURES
  64. num_classes = cfg.VIDEO.HEAD.NUM_CLASSES
  65. dropout_rate = cfg.VIDEO.HEAD.DROPOUT_RATE
  66. activation_func = cfg.VIDEO.HEAD.ACTIVATION
  67. self._construct_head(dim, num_classes, dropout_rate, activation_func)
  68. def _construct_head(self, dim, num_classes, dropout_rate, activation_func):
  69. self.global_avg_pool = nn.AdaptiveAvgPool3d(1)
  70. if dropout_rate > 0.0:
  71. self.dropout = nn.Dropout(dropout_rate)
  72. self.out = nn.Linear(dim, num_classes, bias=True)
  73. if activation_func == 'softmax':
  74. self.activation = nn.Softmax(dim=-1)
  75. elif activation_func == 'sigmoid':
  76. self.activation = nn.Sigmoid()
  77. else:
  78. raise NotImplementedError('{} is not supported as an activation'
  79. 'function.'.format(activation_func))
  80. def forward(self, x):
  81. if len(x.shape) == 5:
  82. x = self.global_avg_pool(x)
  83. # (N, C, T, H, W) -> (N, T, H, W, C).
  84. x = x.permute((0, 2, 3, 4, 1))
  85. if hasattr(self, 'dropout'):
  86. out = self.dropout(x)
  87. else:
  88. out = x
  89. out = self.out(out)
  90. out = self.activation(out)
  91. out = out.view(out.shape[0], -1)
  92. return out, x.view(x.shape[0], -1)
  93. class AvgHead(nn.Module):
  94. """
  95. Constructs base head.
  96. """
  97. def __init__(
  98. self,
  99. cfg,
  100. ):
  101. """
  102. Args:
  103. cfg (Config): global config object.
  104. """
  105. super(AvgHead, self).__init__()
  106. self.cfg = cfg
  107. self.global_avg_pool = nn.AdaptiveAvgPool3d(1)
  108. def forward(self, x):
  109. if len(x.shape) == 5:
  110. x = self.global_avg_pool(x)
  111. # (N, C, T, H, W) -> (N, T, H, W, C).
  112. x = x.permute((0, 2, 3, 4, 1))
  113. out = x.view(x.shape[0], -1)
  114. return out, x.view(x.shape[0], -1)