modeling_convnext.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. # coding=utf-8
  2. # Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch ConvNext model."""
  16. from typing import Optional
  17. import torch
  18. from torch import nn
  19. from ...activations import ACT2FN
  20. from ...modeling_outputs import (
  21. BackboneOutput,
  22. BaseModelOutputWithNoAttention,
  23. BaseModelOutputWithPoolingAndNoAttention,
  24. ImageClassifierOutputWithNoAttention,
  25. )
  26. from ...modeling_utils import PreTrainedModel
  27. from ...utils import auto_docstring, logging
  28. from ...utils.backbone_utils import BackboneMixin
  29. from ...utils.generic import can_return_tuple
  30. from .configuration_convnext import ConvNextConfig
# Module-level logger obtained from the package's `utils.logging` helper; not referenced by the code in this module.
logger = logging.get_logger(__name__)
  32. # Copied from transformers.models.beit.modeling_beit.drop_path
  33. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  34. """
  35. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  36. Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
  37. however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  38. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
  39. layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
  40. argument.
  41. """
  42. if drop_prob == 0.0 or not training:
  43. return input
  44. keep_prob = 1 - drop_prob
  45. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  46. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  47. random_tensor.floor_() # binarize
  48. output = input.div(keep_prob) * random_tensor
  49. return output
  50. # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext
  51. class ConvNextDropPath(nn.Module):
  52. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  53. def __init__(self, drop_prob: Optional[float] = None) -> None:
  54. super().__init__()
  55. self.drop_prob = drop_prob
  56. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  57. return drop_path(hidden_states, self.drop_prob, self.training)
  58. def extra_repr(self) -> str:
  59. return f"p={self.drop_prob}"
  60. class ConvNextLayerNorm(nn.LayerNorm):
  61. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  62. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
  63. width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
  64. """
  65. def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
  66. super().__init__(normalized_shape, eps=eps, **kwargs)
  67. if data_format not in ["channels_last", "channels_first"]:
  68. raise NotImplementedError(f"Unsupported data format: {data_format}")
  69. self.data_format = data_format
  70. def forward(self, features: torch.Tensor) -> torch.Tensor:
  71. """
  72. Args:
  73. features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
  74. """
  75. if self.data_format == "channels_first":
  76. features = features.permute(0, 2, 3, 1)
  77. features = super().forward(features)
  78. features = features.permute(0, 3, 1, 2)
  79. else:
  80. features = super().forward(features)
  81. return features
  82. class ConvNextEmbeddings(nn.Module):
  83. """This class is comparable to (and inspired by) the SwinEmbeddings class
  84. found in src/transformers/models/swin/modeling_swin.py.
  85. """
  86. def __init__(self, config):
  87. super().__init__()
  88. self.patch_embeddings = nn.Conv2d(
  89. config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
  90. )
  91. self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
  92. self.num_channels = config.num_channels
  93. def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  94. num_channels = pixel_values.shape[1]
  95. if num_channels != self.num_channels:
  96. raise ValueError(
  97. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  98. )
  99. embeddings = self.patch_embeddings(pixel_values)
  100. embeddings = self.layernorm(embeddings)
  101. return embeddings
  102. class ConvNextLayer(nn.Module):
  103. """This corresponds to the `Block` class in the original implementation.
  104. There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C,
  105. H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
  106. The authors used (2) as they find it slightly faster in PyTorch.
  107. Args:
  108. config ([`ConvNextConfig`]): Model configuration class.
  109. dim (`int`): Number of input channels.
  110. drop_path (`float`): Stochastic depth rate. Default: 0.0.
  111. """
  112. def __init__(self, config, dim, drop_path=0):
  113. super().__init__()
  114. self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
  115. self.layernorm = ConvNextLayerNorm(dim, eps=1e-6)
  116. self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
  117. self.act = ACT2FN[config.hidden_act]
  118. self.pwconv2 = nn.Linear(4 * dim, dim)
  119. self.layer_scale_parameter = (
  120. nn.Parameter(config.layer_scale_init_value * torch.ones(dim), requires_grad=True)
  121. if config.layer_scale_init_value > 0
  122. else None
  123. )
  124. self.drop_path = ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  125. def forward(self, features: torch.Tensor) -> torch.Tensor:
  126. residual = features
  127. features = self.dwconv(features)
  128. features = features.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
  129. features = self.layernorm(features)
  130. features = self.pwconv1(features)
  131. features = self.act(features)
  132. features = self.pwconv2(features)
  133. if self.layer_scale_parameter is not None:
  134. features = self.layer_scale_parameter * features
  135. features = features.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
  136. features = residual + self.drop_path(features)
  137. return features
  138. class ConvNextStage(nn.Module):
  139. """ConvNeXT stage, consisting of an optional downsampling layer + multiple residual blocks.
  140. Args:
  141. config ([`ConvNextConfig`]): Model configuration class.
  142. in_channels (`int`): Number of input channels.
  143. out_channels (`int`): Number of output channels.
  144. depth (`int`): Number of residual blocks.
  145. drop_path_rates(`list[float]`): Stochastic depth rates for each layer.
  146. """
  147. def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None):
  148. super().__init__()
  149. if in_channels != out_channels or stride > 1:
  150. self.downsampling_layer = nn.ModuleList(
  151. [
  152. ConvNextLayerNorm(in_channels, eps=1e-6, data_format="channels_first"),
  153. nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
  154. ]
  155. )
  156. else:
  157. self.downsampling_layer = nn.ModuleList()
  158. drop_path_rates = drop_path_rates or [0.0] * depth
  159. self.layers = nn.ModuleList(
  160. [ConvNextLayer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
  161. )
  162. def forward(self, features: torch.Tensor) -> torch.Tensor:
  163. for layer in self.downsampling_layer:
  164. features = layer(features)
  165. for layer in self.layers:
  166. features = layer(features)
  167. return features
  168. class ConvNextEncoder(nn.Module):
  169. def __init__(self, config):
  170. super().__init__()
  171. self.stages = nn.ModuleList()
  172. drop_path_rates = [
  173. x.tolist()
  174. for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").split(config.depths)
  175. ]
  176. prev_chs = config.hidden_sizes[0]
  177. for i in range(config.num_stages):
  178. out_chs = config.hidden_sizes[i]
  179. stage = ConvNextStage(
  180. config,
  181. in_channels=prev_chs,
  182. out_channels=out_chs,
  183. stride=2 if i > 0 else 1,
  184. depth=config.depths[i],
  185. drop_path_rates=drop_path_rates[i],
  186. )
  187. self.stages.append(stage)
  188. prev_chs = out_chs
  189. def forward(
  190. self, hidden_states: torch.Tensor, output_hidden_states: Optional[bool] = False
  191. ) -> BaseModelOutputWithNoAttention:
  192. all_hidden_states = [hidden_states] if output_hidden_states else None
  193. for layer_module in self.stages:
  194. hidden_states = layer_module(hidden_states)
  195. if all_hidden_states is not None:
  196. all_hidden_states.append(hidden_states)
  197. return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
  198. @auto_docstring
  199. class ConvNextPreTrainedModel(PreTrainedModel):
  200. config: ConvNextConfig
  201. base_model_prefix = "convnext"
  202. main_input_name = "pixel_values"
  203. _no_split_modules = ["ConvNextLayer"]
  204. _can_record_outputs = {} # hidden states are collected explicitly
  205. def _init_weights(self, module):
  206. """Initialize the weights"""
  207. if isinstance(module, (nn.Linear, nn.Conv2d)):
  208. # Slightly different from the TF version which uses truncated_normal for initialization
  209. # cf https://github.com/pytorch/pytorch/pull/5617
  210. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  211. if module.bias is not None:
  212. module.bias.data.zero_()
  213. elif isinstance(module, (nn.LayerNorm, ConvNextLayerNorm)):
  214. module.bias.data.zero_()
  215. module.weight.data.fill_(1.0)
  216. elif isinstance(module, ConvNextLayer):
  217. if module.layer_scale_parameter is not None:
  218. module.layer_scale_parameter.data.fill_(self.config.layer_scale_init_value)
  219. @auto_docstring
  220. class ConvNextModel(ConvNextPreTrainedModel):
  221. def __init__(self, config):
  222. super().__init__(config)
  223. self.config = config
  224. self.embeddings = ConvNextEmbeddings(config)
  225. self.encoder = ConvNextEncoder(config)
  226. # final layernorm layer
  227. self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
  228. # Initialize weights and apply final processing
  229. self.post_init()
  230. @can_return_tuple
  231. @auto_docstring
  232. def forward(
  233. self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None
  234. ) -> BaseModelOutputWithPoolingAndNoAttention:
  235. if output_hidden_states is None:
  236. output_hidden_states = self.config.output_hidden_states
  237. if pixel_values is None:
  238. raise ValueError("You have to specify pixel_values")
  239. embedding_output = self.embeddings(pixel_values)
  240. encoder_outputs: BaseModelOutputWithNoAttention = self.encoder(
  241. embedding_output, output_hidden_states=output_hidden_states
  242. )
  243. last_hidden_state = encoder_outputs.last_hidden_state
  244. # global average pooling, (N, C, H, W) -> (N, C)
  245. pooled_output = self.layernorm(last_hidden_state.mean([-2, -1]))
  246. return BaseModelOutputWithPoolingAndNoAttention(
  247. last_hidden_state=last_hidden_state,
  248. pooler_output=pooled_output,
  249. hidden_states=encoder_outputs.hidden_states,
  250. )
  251. @auto_docstring(
  252. custom_intro="""
  253. ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  254. ImageNet.
  255. """
  256. )
  257. class ConvNextForImageClassification(ConvNextPreTrainedModel):
  258. accepts_loss_kwargs = False
  259. def __init__(self, config):
  260. super().__init__(config)
  261. self.num_labels = config.num_labels
  262. self.convnext = ConvNextModel(config)
  263. # Classifier head
  264. if config.num_labels > 0:
  265. self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels)
  266. else:
  267. self.classifier = nn.Identity()
  268. # Initialize weights and apply final processing
  269. self.post_init()
  270. @can_return_tuple
  271. @auto_docstring
  272. def forward(
  273. self, pixel_values: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, **kwargs
  274. ) -> ImageClassifierOutputWithNoAttention:
  275. r"""
  276. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  277. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  278. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  279. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  280. """
  281. outputs: BaseModelOutputWithPoolingAndNoAttention = self.convnext(pixel_values, **kwargs)
  282. pooled_output = outputs.pooler_output
  283. logits = self.classifier(pooled_output)
  284. loss = None
  285. if labels is not None:
  286. loss = self.loss_function(labels=labels, pooled_logits=logits, config=self.config)
  287. return ImageClassifierOutputWithNoAttention(
  288. loss=loss,
  289. logits=logits,
  290. hidden_states=outputs.hidden_states,
  291. )
  292. @auto_docstring(
  293. custom_intro="""
  294. ConvNeXt backbone, to be used with frameworks like DETR and MaskFormer.
  295. """
  296. )
  297. class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
  298. has_attentions = False
  299. def __init__(self, config):
  300. super().__init__(config)
  301. super()._init_backbone(config)
  302. self.embeddings = ConvNextEmbeddings(config)
  303. self.encoder = ConvNextEncoder(config)
  304. self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
  305. # Add layer norms to hidden states of out_features
  306. hidden_states_norms = {}
  307. for stage, num_channels in zip(self._out_features, self.channels):
  308. hidden_states_norms[stage] = ConvNextLayerNorm(num_channels, data_format="channels_first")
  309. self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
  310. # initialize weights and apply final processing
  311. self.post_init()
  312. @can_return_tuple
  313. @auto_docstring
  314. def forward(
  315. self,
  316. pixel_values: torch.Tensor,
  317. output_hidden_states: Optional[bool] = None,
  318. ) -> BackboneOutput:
  319. r"""
  320. Examples:
  321. ```python
  322. >>> from transformers import AutoImageProcessor, AutoBackbone
  323. >>> import torch
  324. >>> from PIL import Image
  325. >>> import requests
  326. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  327. >>> image = Image.open(requests.get(url, stream=True).raw)
  328. >>> processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
  329. >>> model = AutoBackbone.from_pretrained("facebook/convnext-tiny-224")
  330. >>> inputs = processor(image, return_tensors="pt")
  331. >>> outputs = model(**inputs)
  332. ```"""
  333. if output_hidden_states is None:
  334. output_hidden_states = self.config.output_hidden_states
  335. embedding_output = self.embeddings(pixel_values)
  336. outputs: BaseModelOutputWithPoolingAndNoAttention = self.encoder(embedding_output, output_hidden_states=True)
  337. hidden_states = outputs.hidden_states
  338. feature_maps = []
  339. for stage, hidden_state in zip(self.stage_names, hidden_states):
  340. if stage in self.out_features:
  341. hidden_state = self.hidden_states_norms[stage](hidden_state)
  342. feature_maps.append(hidden_state)
  343. return BackboneOutput(
  344. feature_maps=tuple(feature_maps),
  345. hidden_states=hidden_states if output_hidden_states else None,
  346. )
  347. __all__ = ["ConvNextForImageClassification", "ConvNextModel", "ConvNextPreTrainedModel", "ConvNextBackbone"]