modeling_resnet.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. # coding=utf-8
  2. # Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch ResNet model."""
  16. import math
  17. from typing import Optional
  18. import torch
  19. from torch import Tensor, nn
  20. from ...activations import ACT2FN
  21. from ...modeling_outputs import (
  22. BackboneOutput,
  23. BaseModelOutputWithNoAttention,
  24. BaseModelOutputWithPoolingAndNoAttention,
  25. ImageClassifierOutputWithNoAttention,
  26. )
  27. from ...modeling_utils import PreTrainedModel
  28. from ...utils import auto_docstring, logging
  29. from ...utils.backbone_utils import BackboneMixin
  30. from .configuration_resnet import ResNetConfig
# Module-level logger, created through the library's own `logging` utilities (imported from `...utils`).
logger = logging.get_logger(__name__)
  32. class ResNetConvLayer(nn.Module):
  33. def __init__(
  34. self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
  35. ):
  36. super().__init__()
  37. self.convolution = nn.Conv2d(
  38. in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False
  39. )
  40. self.normalization = nn.BatchNorm2d(out_channels)
  41. self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
  42. def forward(self, input: Tensor) -> Tensor:
  43. hidden_state = self.convolution(input)
  44. hidden_state = self.normalization(hidden_state)
  45. hidden_state = self.activation(hidden_state)
  46. return hidden_state
  47. class ResNetEmbeddings(nn.Module):
  48. """
  49. ResNet Embeddings (stem) composed of a single aggressive convolution.
  50. """
  51. def __init__(self, config: ResNetConfig):
  52. super().__init__()
  53. self.embedder = ResNetConvLayer(
  54. config.num_channels, config.embedding_size, kernel_size=7, stride=2, activation=config.hidden_act
  55. )
  56. self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  57. self.num_channels = config.num_channels
  58. def forward(self, pixel_values: Tensor) -> Tensor:
  59. num_channels = pixel_values.shape[1]
  60. if num_channels != self.num_channels:
  61. raise ValueError(
  62. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  63. )
  64. embedding = self.embedder(pixel_values)
  65. embedding = self.pooler(embedding)
  66. return embedding
  67. class ResNetShortCut(nn.Module):
  68. """
  69. ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
  70. downsample the input using `stride=2`.
  71. """
  72. def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
  73. super().__init__()
  74. self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
  75. self.normalization = nn.BatchNorm2d(out_channels)
  76. def forward(self, input: Tensor) -> Tensor:
  77. hidden_state = self.convolution(input)
  78. hidden_state = self.normalization(hidden_state)
  79. return hidden_state
  80. class ResNetBasicLayer(nn.Module):
  81. """
  82. A classic ResNet's residual layer composed by two `3x3` convolutions.
  83. """
  84. def __init__(self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu"):
  85. super().__init__()
  86. should_apply_shortcut = in_channels != out_channels or stride != 1
  87. self.shortcut = (
  88. ResNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
  89. )
  90. self.layer = nn.Sequential(
  91. ResNetConvLayer(in_channels, out_channels, stride=stride),
  92. ResNetConvLayer(out_channels, out_channels, activation=None),
  93. )
  94. self.activation = ACT2FN[activation]
  95. def forward(self, hidden_state):
  96. residual = hidden_state
  97. hidden_state = self.layer(hidden_state)
  98. residual = self.shortcut(residual)
  99. hidden_state += residual
  100. hidden_state = self.activation(hidden_state)
  101. return hidden_state
  102. class ResNetBottleNeckLayer(nn.Module):
  103. """
  104. A classic ResNet's bottleneck layer composed by three `3x3` convolutions.
  105. The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
  106. convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. If
  107. `downsample_in_bottleneck` is true, downsample will be in the first layer instead of the second layer.
  108. """
  109. def __init__(
  110. self,
  111. in_channels: int,
  112. out_channels: int,
  113. stride: int = 1,
  114. activation: str = "relu",
  115. reduction: int = 4,
  116. downsample_in_bottleneck: bool = False,
  117. ):
  118. super().__init__()
  119. should_apply_shortcut = in_channels != out_channels or stride != 1
  120. reduces_channels = out_channels // reduction
  121. self.shortcut = (
  122. ResNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
  123. )
  124. self.layer = nn.Sequential(
  125. ResNetConvLayer(
  126. in_channels, reduces_channels, kernel_size=1, stride=stride if downsample_in_bottleneck else 1
  127. ),
  128. ResNetConvLayer(reduces_channels, reduces_channels, stride=stride if not downsample_in_bottleneck else 1),
  129. ResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None),
  130. )
  131. self.activation = ACT2FN[activation]
  132. def forward(self, hidden_state):
  133. residual = hidden_state
  134. hidden_state = self.layer(hidden_state)
  135. residual = self.shortcut(residual)
  136. hidden_state += residual
  137. hidden_state = self.activation(hidden_state)
  138. return hidden_state
  139. class ResNetStage(nn.Module):
  140. """
  141. A ResNet stage composed by stacked layers.
  142. """
  143. def __init__(
  144. self,
  145. config: ResNetConfig,
  146. in_channels: int,
  147. out_channels: int,
  148. stride: int = 2,
  149. depth: int = 2,
  150. ):
  151. super().__init__()
  152. layer = ResNetBottleNeckLayer if config.layer_type == "bottleneck" else ResNetBasicLayer
  153. if config.layer_type == "bottleneck":
  154. first_layer = layer(
  155. in_channels,
  156. out_channels,
  157. stride=stride,
  158. activation=config.hidden_act,
  159. downsample_in_bottleneck=config.downsample_in_bottleneck,
  160. )
  161. else:
  162. first_layer = layer(in_channels, out_channels, stride=stride, activation=config.hidden_act)
  163. self.layers = nn.Sequential(
  164. first_layer, *[layer(out_channels, out_channels, activation=config.hidden_act) for _ in range(depth - 1)]
  165. )
  166. def forward(self, input: Tensor) -> Tensor:
  167. hidden_state = input
  168. for layer in self.layers:
  169. hidden_state = layer(hidden_state)
  170. return hidden_state
  171. class ResNetEncoder(nn.Module):
  172. def __init__(self, config: ResNetConfig):
  173. super().__init__()
  174. self.stages = nn.ModuleList([])
  175. # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input
  176. self.stages.append(
  177. ResNetStage(
  178. config,
  179. config.embedding_size,
  180. config.hidden_sizes[0],
  181. stride=2 if config.downsample_in_first_stage else 1,
  182. depth=config.depths[0],
  183. )
  184. )
  185. in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
  186. for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]):
  187. self.stages.append(ResNetStage(config, in_channels, out_channels, depth=depth))
  188. def forward(
  189. self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
  190. ) -> BaseModelOutputWithNoAttention:
  191. hidden_states = () if output_hidden_states else None
  192. for stage_module in self.stages:
  193. if output_hidden_states:
  194. hidden_states = hidden_states + (hidden_state,)
  195. hidden_state = stage_module(hidden_state)
  196. if output_hidden_states:
  197. hidden_states = hidden_states + (hidden_state,)
  198. if not return_dict:
  199. return tuple(v for v in [hidden_state, hidden_states] if v is not None)
  200. return BaseModelOutputWithNoAttention(
  201. last_hidden_state=hidden_state,
  202. hidden_states=hidden_states,
  203. )
  204. @auto_docstring
  205. class ResNetPreTrainedModel(PreTrainedModel):
  206. config: ResNetConfig
  207. base_model_prefix = "resnet"
  208. main_input_name = "pixel_values"
  209. _no_split_modules = ["ResNetConvLayer", "ResNetShortCut"]
  210. def _init_weights(self, module):
  211. if isinstance(module, nn.Conv2d):
  212. nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
  213. # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
  214. elif isinstance(module, nn.Linear):
  215. nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
  216. if module.bias is not None:
  217. fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
  218. bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
  219. nn.init.uniform_(module.bias, -bound, bound)
  220. elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
  221. nn.init.constant_(module.weight, 1)
  222. nn.init.constant_(module.bias, 0)
  223. @auto_docstring
  224. class ResNetModel(ResNetPreTrainedModel):
  225. def __init__(self, config):
  226. super().__init__(config)
  227. self.config = config
  228. self.embedder = ResNetEmbeddings(config)
  229. self.encoder = ResNetEncoder(config)
  230. self.pooler = nn.AdaptiveAvgPool2d((1, 1))
  231. # Initialize weights and apply final processing
  232. self.post_init()
  233. @auto_docstring
  234. def forward(
  235. self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
  236. ) -> BaseModelOutputWithPoolingAndNoAttention:
  237. output_hidden_states = (
  238. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  239. )
  240. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  241. embedding_output = self.embedder(pixel_values)
  242. encoder_outputs = self.encoder(
  243. embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict
  244. )
  245. last_hidden_state = encoder_outputs[0]
  246. pooled_output = self.pooler(last_hidden_state)
  247. if not return_dict:
  248. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  249. return BaseModelOutputWithPoolingAndNoAttention(
  250. last_hidden_state=last_hidden_state,
  251. pooler_output=pooled_output,
  252. hidden_states=encoder_outputs.hidden_states,
  253. )
  254. @auto_docstring(
  255. custom_intro="""
  256. ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  257. ImageNet.
  258. """
  259. )
  260. class ResNetForImageClassification(ResNetPreTrainedModel):
  261. def __init__(self, config):
  262. super().__init__(config)
  263. self.num_labels = config.num_labels
  264. self.resnet = ResNetModel(config)
  265. # classification head
  266. self.classifier = nn.Sequential(
  267. nn.Flatten(),
  268. nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(),
  269. )
  270. # initialize weights and apply final processing
  271. self.post_init()
  272. @auto_docstring
  273. def forward(
  274. self,
  275. pixel_values: Optional[torch.FloatTensor] = None,
  276. labels: Optional[torch.LongTensor] = None,
  277. output_hidden_states: Optional[bool] = None,
  278. return_dict: Optional[bool] = None,
  279. ) -> ImageClassifierOutputWithNoAttention:
  280. r"""
  281. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  282. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  283. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  284. """
  285. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  286. outputs = self.resnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  287. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  288. logits = self.classifier(pooled_output)
  289. loss = None
  290. if labels is not None:
  291. loss = self.loss_function(labels, logits, self.config)
  292. if not return_dict:
  293. output = (logits,) + outputs[2:]
  294. return (loss,) + output if loss is not None else output
  295. return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
  296. @auto_docstring(
  297. custom_intro="""
  298. ResNet backbone, to be used with frameworks like DETR and MaskFormer.
  299. """
  300. )
  301. class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
  302. has_attentions = False
  303. def __init__(self, config):
  304. super().__init__(config)
  305. super()._init_backbone(config)
  306. self.num_features = [config.embedding_size] + config.hidden_sizes
  307. self.embedder = ResNetEmbeddings(config)
  308. self.encoder = ResNetEncoder(config)
  309. # initialize weights and apply final processing
  310. self.post_init()
  311. @auto_docstring
  312. def forward(
  313. self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
  314. ) -> BackboneOutput:
  315. r"""
  316. Examples:
  317. ```python
  318. >>> from transformers import AutoImageProcessor, AutoBackbone
  319. >>> import torch
  320. >>> from PIL import Image
  321. >>> import requests
  322. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  323. >>> image = Image.open(requests.get(url, stream=True).raw)
  324. >>> processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
  325. >>> model = AutoBackbone.from_pretrained(
  326. ... "microsoft/resnet-50", out_features=["stage1", "stage2", "stage3", "stage4"]
  327. ... )
  328. >>> inputs = processor(image, return_tensors="pt")
  329. >>> outputs = model(**inputs)
  330. >>> feature_maps = outputs.feature_maps
  331. >>> list(feature_maps[-1].shape)
  332. [1, 2048, 7, 7]
  333. ```"""
  334. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  335. output_hidden_states = (
  336. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  337. )
  338. embedding_output = self.embedder(pixel_values)
  339. outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)
  340. hidden_states = outputs.hidden_states
  341. feature_maps = ()
  342. for idx, stage in enumerate(self.stage_names):
  343. if stage in self.out_features:
  344. feature_maps += (hidden_states[idx],)
  345. if not return_dict:
  346. output = (feature_maps,)
  347. if output_hidden_states:
  348. output += (outputs.hidden_states,)
  349. return output
  350. return BackboneOutput(
  351. feature_maps=feature_maps,
  352. hidden_states=outputs.hidden_states if output_hidden_states else None,
  353. attentions=None,
  354. )
# Public names exported by `from ... import *`: the three model heads and their shared pretrained base class.
__all__ = ["ResNetForImageClassification", "ResNetModel", "ResNetPreTrainedModel", "ResNetBackbone"]