modeling_regnet.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. # coding=utf-8
  2. # Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch RegNet model."""
  16. import math
  17. from typing import Optional
  18. import torch
  19. from torch import Tensor, nn
  20. from ...activations import ACT2FN
  21. from ...modeling_outputs import (
  22. BaseModelOutputWithNoAttention,
  23. BaseModelOutputWithPoolingAndNoAttention,
  24. ImageClassifierOutputWithNoAttention,
  25. )
  26. from ...modeling_utils import PreTrainedModel
  27. from ...utils import auto_docstring, logging
  28. from .configuration_regnet import RegNetConfig
# Module-level logger obtained from the library's `logging` helper, keyed by this module's `__name__`.
logger = logging.get_logger(__name__)
  30. class RegNetConvLayer(nn.Module):
  31. def __init__(
  32. self,
  33. in_channels: int,
  34. out_channels: int,
  35. kernel_size: int = 3,
  36. stride: int = 1,
  37. groups: int = 1,
  38. activation: Optional[str] = "relu",
  39. ):
  40. super().__init__()
  41. self.convolution = nn.Conv2d(
  42. in_channels,
  43. out_channels,
  44. kernel_size=kernel_size,
  45. stride=stride,
  46. padding=kernel_size // 2,
  47. groups=groups,
  48. bias=False,
  49. )
  50. self.normalization = nn.BatchNorm2d(out_channels)
  51. self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
  52. def forward(self, hidden_state):
  53. hidden_state = self.convolution(hidden_state)
  54. hidden_state = self.normalization(hidden_state)
  55. hidden_state = self.activation(hidden_state)
  56. return hidden_state
  57. class RegNetEmbeddings(nn.Module):
  58. """
  59. RegNet Embeddings (stem) composed of a single aggressive convolution.
  60. """
  61. def __init__(self, config: RegNetConfig):
  62. super().__init__()
  63. self.embedder = RegNetConvLayer(
  64. config.num_channels, config.embedding_size, kernel_size=3, stride=2, activation=config.hidden_act
  65. )
  66. self.num_channels = config.num_channels
  67. def forward(self, pixel_values):
  68. num_channels = pixel_values.shape[1]
  69. if num_channels != self.num_channels:
  70. raise ValueError(
  71. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  72. )
  73. hidden_state = self.embedder(pixel_values)
  74. return hidden_state
  75. # Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut with ResNet->RegNet
  76. class RegNetShortCut(nn.Module):
  77. """
  78. RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
  79. downsample the input using `stride=2`.
  80. """
  81. def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
  82. super().__init__()
  83. self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
  84. self.normalization = nn.BatchNorm2d(out_channels)
  85. def forward(self, input: Tensor) -> Tensor:
  86. hidden_state = self.convolution(input)
  87. hidden_state = self.normalization(hidden_state)
  88. return hidden_state
  89. class RegNetSELayer(nn.Module):
  90. """
  91. Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://huggingface.co/papers/1709.01507).
  92. """
  93. def __init__(self, in_channels: int, reduced_channels: int):
  94. super().__init__()
  95. self.pooler = nn.AdaptiveAvgPool2d((1, 1))
  96. self.attention = nn.Sequential(
  97. nn.Conv2d(in_channels, reduced_channels, kernel_size=1),
  98. nn.ReLU(),
  99. nn.Conv2d(reduced_channels, in_channels, kernel_size=1),
  100. nn.Sigmoid(),
  101. )
  102. def forward(self, hidden_state):
  103. # b c h w -> b c 1 1
  104. pooled = self.pooler(hidden_state)
  105. attention = self.attention(pooled)
  106. hidden_state = hidden_state * attention
  107. return hidden_state
  108. class RegNetXLayer(nn.Module):
  109. """
  110. RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1.
  111. """
  112. def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1):
  113. super().__init__()
  114. should_apply_shortcut = in_channels != out_channels or stride != 1
  115. groups = max(1, out_channels // config.groups_width)
  116. self.shortcut = (
  117. RegNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
  118. )
  119. self.layer = nn.Sequential(
  120. RegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act),
  121. RegNetConvLayer(out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act),
  122. RegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None),
  123. )
  124. self.activation = ACT2FN[config.hidden_act]
  125. def forward(self, hidden_state):
  126. residual = hidden_state
  127. hidden_state = self.layer(hidden_state)
  128. residual = self.shortcut(residual)
  129. hidden_state += residual
  130. hidden_state = self.activation(hidden_state)
  131. return hidden_state
  132. class RegNetYLayer(nn.Module):
  133. """
  134. RegNet's Y layer: an X layer with Squeeze and Excitation.
  135. """
  136. def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1):
  137. super().__init__()
  138. should_apply_shortcut = in_channels != out_channels or stride != 1
  139. groups = max(1, out_channels // config.groups_width)
  140. self.shortcut = (
  141. RegNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
  142. )
  143. self.layer = nn.Sequential(
  144. RegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act),
  145. RegNetConvLayer(out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act),
  146. RegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4))),
  147. RegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None),
  148. )
  149. self.activation = ACT2FN[config.hidden_act]
  150. def forward(self, hidden_state):
  151. residual = hidden_state
  152. hidden_state = self.layer(hidden_state)
  153. residual = self.shortcut(residual)
  154. hidden_state += residual
  155. hidden_state = self.activation(hidden_state)
  156. return hidden_state
  157. class RegNetStage(nn.Module):
  158. """
  159. A RegNet stage composed by stacked layers.
  160. """
  161. def __init__(
  162. self,
  163. config: RegNetConfig,
  164. in_channels: int,
  165. out_channels: int,
  166. stride: int = 2,
  167. depth: int = 2,
  168. ):
  169. super().__init__()
  170. layer = RegNetXLayer if config.layer_type == "x" else RegNetYLayer
  171. self.layers = nn.Sequential(
  172. # downsampling is done in the first layer with stride of 2
  173. layer(
  174. config,
  175. in_channels,
  176. out_channels,
  177. stride=stride,
  178. ),
  179. *[layer(config, out_channels, out_channels) for _ in range(depth - 1)],
  180. )
  181. def forward(self, hidden_state):
  182. hidden_state = self.layers(hidden_state)
  183. return hidden_state
  184. class RegNetEncoder(nn.Module):
  185. def __init__(self, config: RegNetConfig):
  186. super().__init__()
  187. self.stages = nn.ModuleList([])
  188. # based on `downsample_in_first_stage`, the first layer of the first stage may or may not downsample the input
  189. self.stages.append(
  190. RegNetStage(
  191. config,
  192. config.embedding_size,
  193. config.hidden_sizes[0],
  194. stride=2 if config.downsample_in_first_stage else 1,
  195. depth=config.depths[0],
  196. )
  197. )
  198. in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
  199. for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]):
  200. self.stages.append(RegNetStage(config, in_channels, out_channels, depth=depth))
  201. def forward(
  202. self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
  203. ) -> BaseModelOutputWithNoAttention:
  204. hidden_states = () if output_hidden_states else None
  205. for stage_module in self.stages:
  206. if output_hidden_states:
  207. hidden_states = hidden_states + (hidden_state,)
  208. hidden_state = stage_module(hidden_state)
  209. if output_hidden_states:
  210. hidden_states = hidden_states + (hidden_state,)
  211. if not return_dict:
  212. return tuple(v for v in [hidden_state, hidden_states] if v is not None)
  213. return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
  214. @auto_docstring
  215. class RegNetPreTrainedModel(PreTrainedModel):
  216. config: RegNetConfig
  217. base_model_prefix = "regnet"
  218. main_input_name = "pixel_values"
  219. _no_split_modules = ["RegNetYLayer"]
  220. # Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights
  221. def _init_weights(self, module):
  222. if isinstance(module, nn.Conv2d):
  223. nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
  224. # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
  225. elif isinstance(module, nn.Linear):
  226. nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
  227. if module.bias is not None:
  228. fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
  229. bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
  230. nn.init.uniform_(module.bias, -bound, bound)
  231. elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
  232. nn.init.constant_(module.weight, 1)
  233. nn.init.constant_(module.bias, 0)
  234. @auto_docstring
  235. # Copied from transformers.models.resnet.modeling_resnet.ResNetModel with RESNET->REGNET,ResNet->RegNet
  236. class RegNetModel(RegNetPreTrainedModel):
  237. def __init__(self, config):
  238. super().__init__(config)
  239. self.config = config
  240. self.embedder = RegNetEmbeddings(config)
  241. self.encoder = RegNetEncoder(config)
  242. self.pooler = nn.AdaptiveAvgPool2d((1, 1))
  243. # Initialize weights and apply final processing
  244. self.post_init()
  245. @auto_docstring
  246. def forward(
  247. self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
  248. ) -> BaseModelOutputWithPoolingAndNoAttention:
  249. output_hidden_states = (
  250. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  251. )
  252. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  253. embedding_output = self.embedder(pixel_values)
  254. encoder_outputs = self.encoder(
  255. embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict
  256. )
  257. last_hidden_state = encoder_outputs[0]
  258. pooled_output = self.pooler(last_hidden_state)
  259. if not return_dict:
  260. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  261. return BaseModelOutputWithPoolingAndNoAttention(
  262. last_hidden_state=last_hidden_state,
  263. pooler_output=pooled_output,
  264. hidden_states=encoder_outputs.hidden_states,
  265. )
  266. @auto_docstring(
  267. custom_intro="""
  268. RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  269. ImageNet.
  270. """
  271. )
  272. # Copied from transformers.models.resnet.modeling_resnet.ResNetForImageClassification with RESNET->REGNET,ResNet->RegNet,resnet->regnet
  273. class RegNetForImageClassification(RegNetPreTrainedModel):
  274. def __init__(self, config):
  275. super().__init__(config)
  276. self.num_labels = config.num_labels
  277. self.regnet = RegNetModel(config)
  278. # classification head
  279. self.classifier = nn.Sequential(
  280. nn.Flatten(),
  281. nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(),
  282. )
  283. # initialize weights and apply final processing
  284. self.post_init()
  285. @auto_docstring
  286. def forward(
  287. self,
  288. pixel_values: Optional[torch.FloatTensor] = None,
  289. labels: Optional[torch.LongTensor] = None,
  290. output_hidden_states: Optional[bool] = None,
  291. return_dict: Optional[bool] = None,
  292. ) -> ImageClassifierOutputWithNoAttention:
  293. r"""
  294. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  295. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  296. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  297. """
  298. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  299. outputs = self.regnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  300. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  301. logits = self.classifier(pooled_output)
  302. loss = None
  303. if labels is not None:
  304. loss = self.loss_function(labels, logits, self.config)
  305. if not return_dict:
  306. output = (logits,) + outputs[2:]
  307. return (loss,) + output if loss is not None else output
  308. return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
  309. __all__ = ["RegNetForImageClassification", "RegNetModel", "RegNetPreTrainedModel"]