modeling_convnextv2.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. # coding=utf-8
  2. # Copyright 2023 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch ConvNextV2 model."""
  16. from typing import Optional
  17. import torch
  18. from torch import nn
  19. from ...activations import ACT2FN
  20. from ...modeling_outputs import (
  21. BackboneOutput,
  22. BaseModelOutputWithNoAttention,
  23. BaseModelOutputWithPoolingAndNoAttention,
  24. ImageClassifierOutputWithNoAttention,
  25. )
  26. from ...modeling_utils import PreTrainedModel
  27. from ...utils import auto_docstring, logging
  28. from ...utils.backbone_utils import BackboneMixin
  29. from ...utils.generic import can_return_tuple
  30. from .configuration_convnextv2 import ConvNextV2Config
# Module logger created via the library's `logging` utility (imported from `...utils`).
# NOTE(review): `logger` is not referenced in the code visible here; kept for module-level use.
logger = logging.get_logger(__name__)
  32. # Copied from transformers.models.beit.modeling_beit.drop_path
  33. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  34. """
  35. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  36. Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
  37. however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  38. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
  39. layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
  40. argument.
  41. """
  42. if drop_prob == 0.0 or not training:
  43. return input
  44. keep_prob = 1 - drop_prob
  45. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  46. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  47. random_tensor.floor_() # binarize
  48. output = input.div(keep_prob) * random_tensor
  49. return output
  50. # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNextV2
  51. class ConvNextV2DropPath(nn.Module):
  52. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  53. def __init__(self, drop_prob: Optional[float] = None) -> None:
  54. super().__init__()
  55. self.drop_prob = drop_prob
  56. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  57. return drop_path(hidden_states, self.drop_prob, self.training)
  58. def extra_repr(self) -> str:
  59. return f"p={self.drop_prob}"
  60. class ConvNextV2GRN(nn.Module):
  61. """GRN (Global Response Normalization) layer"""
  62. def __init__(self, dim: int):
  63. super().__init__()
  64. self.weight = nn.Parameter(torch.zeros(1, 1, 1, dim))
  65. self.bias = nn.Parameter(torch.zeros(1, 1, 1, dim))
  66. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  67. # Compute and normalize global spatial feature maps
  68. global_features = torch.linalg.vector_norm(hidden_states, ord=2, dim=(1, 2), keepdim=True)
  69. norm_features = global_features / (global_features.mean(dim=-1, keepdim=True) + 1e-6)
  70. hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states
  71. return hidden_states
  72. # Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->ConvNextV2
  73. class ConvNextV2LayerNorm(nn.LayerNorm):
  74. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  75. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
  76. width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
  77. """
  78. def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
  79. super().__init__(normalized_shape, eps=eps, **kwargs)
  80. if data_format not in ["channels_last", "channels_first"]:
  81. raise NotImplementedError(f"Unsupported data format: {data_format}")
  82. self.data_format = data_format
  83. def forward(self, features: torch.Tensor) -> torch.Tensor:
  84. """
  85. Args:
  86. features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
  87. """
  88. if self.data_format == "channels_first":
  89. features = features.permute(0, 2, 3, 1)
  90. features = super().forward(features)
  91. features = features.permute(0, 3, 1, 2)
  92. else:
  93. features = super().forward(features)
  94. return features
  95. # Copied from transformers.models.convnext.modeling_convnext.ConvNextEmbeddings with ConvNext->ConvNextV2
  96. class ConvNextV2Embeddings(nn.Module):
  97. """This class is comparable to (and inspired by) the SwinEmbeddings class
  98. found in src/transformers/models/swin/modeling_swin.py.
  99. """
  100. def __init__(self, config):
  101. super().__init__()
  102. self.patch_embeddings = nn.Conv2d(
  103. config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
  104. )
  105. self.layernorm = ConvNextV2LayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
  106. self.num_channels = config.num_channels
  107. def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  108. num_channels = pixel_values.shape[1]
  109. if num_channels != self.num_channels:
  110. raise ValueError(
  111. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  112. )
  113. embeddings = self.patch_embeddings(pixel_values)
  114. embeddings = self.layernorm(embeddings)
  115. return embeddings
  116. class ConvNextV2Layer(nn.Module):
  117. """This corresponds to the `Block` class in the original implementation.
  118. There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C,
  119. H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
  120. The authors used (2) as they find it slightly faster in PyTorch.
  121. Args:
  122. config ([`ConvNextV2Config`]): Model configuration class.
  123. dim (`int`): Number of input channels.
  124. drop_path (`float`): Stochastic depth rate. Default: 0.0.
  125. """
  126. def __init__(self, config, dim, drop_path=0):
  127. super().__init__()
  128. # depthwise conv
  129. self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
  130. self.layernorm = ConvNextV2LayerNorm(dim, eps=1e-6)
  131. # pointwise/1x1 convs, implemented with linear layers
  132. self.pwconv1 = nn.Linear(dim, 4 * dim)
  133. self.act = ACT2FN[config.hidden_act]
  134. self.grn = ConvNextV2GRN(4 * dim)
  135. self.pwconv2 = nn.Linear(4 * dim, dim)
  136. self.drop_path = ConvNextV2DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  137. def forward(self, features: torch.Tensor) -> torch.Tensor:
  138. residual = features
  139. features = self.dwconv(features)
  140. # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
  141. features = features.permute(0, 2, 3, 1)
  142. features = self.layernorm(features)
  143. features = self.pwconv1(features)
  144. features = self.act(features)
  145. features = self.grn(features)
  146. features = self.pwconv2(features)
  147. # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width)
  148. features = features.permute(0, 3, 1, 2)
  149. features = residual + self.drop_path(features)
  150. return features
  151. # Copied from transformers.models.convnext.modeling_convnext.ConvNextStage with ConvNeXT->ConvNeXTV2, ConvNext->ConvNextV2
  152. class ConvNextV2Stage(nn.Module):
  153. """ConvNeXTV2 stage, consisting of an optional downsampling layer + multiple residual blocks.
  154. Args:
  155. config ([`ConvNextV2Config`]): Model configuration class.
  156. in_channels (`int`): Number of input channels.
  157. out_channels (`int`): Number of output channels.
  158. depth (`int`): Number of residual blocks.
  159. drop_path_rates(`list[float]`): Stochastic depth rates for each layer.
  160. """
  161. def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None):
  162. super().__init__()
  163. if in_channels != out_channels or stride > 1:
  164. self.downsampling_layer = nn.ModuleList(
  165. [
  166. ConvNextV2LayerNorm(in_channels, eps=1e-6, data_format="channels_first"),
  167. nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
  168. ]
  169. )
  170. else:
  171. self.downsampling_layer = nn.ModuleList()
  172. drop_path_rates = drop_path_rates or [0.0] * depth
  173. self.layers = nn.ModuleList(
  174. [ConvNextV2Layer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
  175. )
  176. def forward(self, features: torch.Tensor) -> torch.Tensor:
  177. for layer in self.downsampling_layer:
  178. features = layer(features)
  179. for layer in self.layers:
  180. features = layer(features)
  181. return features
  182. # Copied from transformers.models.convnext.modeling_convnext.ConvNextEncoder with ConvNext->ConvNextV2
  183. class ConvNextV2Encoder(nn.Module):
  184. def __init__(self, config):
  185. super().__init__()
  186. self.stages = nn.ModuleList()
  187. drop_path_rates = [
  188. x.tolist()
  189. for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").split(config.depths)
  190. ]
  191. prev_chs = config.hidden_sizes[0]
  192. for i in range(config.num_stages):
  193. out_chs = config.hidden_sizes[i]
  194. stage = ConvNextV2Stage(
  195. config,
  196. in_channels=prev_chs,
  197. out_channels=out_chs,
  198. stride=2 if i > 0 else 1,
  199. depth=config.depths[i],
  200. drop_path_rates=drop_path_rates[i],
  201. )
  202. self.stages.append(stage)
  203. prev_chs = out_chs
  204. def forward(
  205. self, hidden_states: torch.Tensor, output_hidden_states: Optional[bool] = False
  206. ) -> BaseModelOutputWithNoAttention:
  207. all_hidden_states = [hidden_states] if output_hidden_states else None
  208. for layer_module in self.stages:
  209. hidden_states = layer_module(hidden_states)
  210. if all_hidden_states is not None:
  211. all_hidden_states.append(hidden_states)
  212. return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
  213. @auto_docstring
  214. class ConvNextV2PreTrainedModel(PreTrainedModel):
  215. config: ConvNextV2Config
  216. base_model_prefix = "convnextv2"
  217. main_input_name = "pixel_values"
  218. _no_split_modules = ["ConvNextV2Layer"]
  219. def _init_weights(self, module):
  220. """Initialize the weights"""
  221. if isinstance(module, (nn.Linear, nn.Conv2d)):
  222. # Slightly different from the TF version which uses truncated_normal for initialization
  223. # cf https://github.com/pytorch/pytorch/pull/5617
  224. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  225. if module.bias is not None:
  226. module.bias.data.zero_()
  227. elif isinstance(module, (nn.LayerNorm, ConvNextV2LayerNorm)):
  228. module.bias.data.zero_()
  229. module.weight.data.fill_(1.0)
  230. elif isinstance(module, ConvNextV2GRN):
  231. module.weight.data.zero_()
  232. module.bias.data.zero_()
  233. @auto_docstring
  234. # Copied from transformers.models.convnext.modeling_convnext.ConvNextModel with CONVNEXT->CONVNEXTV2, ConvNext->ConvNextV2
  235. class ConvNextV2Model(ConvNextV2PreTrainedModel):
  236. def __init__(self, config):
  237. super().__init__(config)
  238. self.config = config
  239. self.embeddings = ConvNextV2Embeddings(config)
  240. self.encoder = ConvNextV2Encoder(config)
  241. # final layernorm layer
  242. self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
  243. # Initialize weights and apply final processing
  244. self.post_init()
  245. @can_return_tuple
  246. @auto_docstring
  247. def forward(
  248. self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None
  249. ) -> BaseModelOutputWithPoolingAndNoAttention:
  250. if output_hidden_states is None:
  251. output_hidden_states = self.config.output_hidden_states
  252. if pixel_values is None:
  253. raise ValueError("You have to specify pixel_values")
  254. embedding_output = self.embeddings(pixel_values)
  255. encoder_outputs: BaseModelOutputWithNoAttention = self.encoder(
  256. embedding_output, output_hidden_states=output_hidden_states
  257. )
  258. last_hidden_state = encoder_outputs.last_hidden_state
  259. # global average pooling, (N, C, H, W) -> (N, C)
  260. pooled_output = self.layernorm(last_hidden_state.mean([-2, -1]))
  261. return BaseModelOutputWithPoolingAndNoAttention(
  262. last_hidden_state=last_hidden_state,
  263. pooler_output=pooled_output,
  264. hidden_states=encoder_outputs.hidden_states,
  265. )
  266. @auto_docstring(
  267. custom_intro="""
  268. ConvNextV2 Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  269. ImageNet.
  270. """
  271. )
  272. # Copied from transformers.models.convnext.modeling_convnext.ConvNextForImageClassification with CONVNEXT->CONVNEXTV2,ConvNext->ConvNextV2,convnext->convnextv2
  273. class ConvNextV2ForImageClassification(ConvNextV2PreTrainedModel):
  274. accepts_loss_kwargs = False
  275. def __init__(self, config):
  276. super().__init__(config)
  277. self.num_labels = config.num_labels
  278. self.convnextv2 = ConvNextV2Model(config)
  279. # Classifier head
  280. if config.num_labels > 0:
  281. self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels)
  282. else:
  283. self.classifier = nn.Identity()
  284. # Initialize weights and apply final processing
  285. self.post_init()
  286. @can_return_tuple
  287. @auto_docstring
  288. def forward(
  289. self, pixel_values: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, **kwargs
  290. ) -> ImageClassifierOutputWithNoAttention:
  291. r"""
  292. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  293. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  294. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  295. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  296. """
  297. outputs: BaseModelOutputWithPoolingAndNoAttention = self.convnextv2(pixel_values, **kwargs)
  298. pooled_output = outputs.pooler_output
  299. logits = self.classifier(pooled_output)
  300. loss = None
  301. if labels is not None:
  302. loss = self.loss_function(labels=labels, pooled_logits=logits, config=self.config)
  303. return ImageClassifierOutputWithNoAttention(
  304. loss=loss,
  305. logits=logits,
  306. hidden_states=outputs.hidden_states,
  307. )
  308. @auto_docstring(
  309. custom_intro="""
  310. ConvNeXT V2 backbone, to be used with frameworks like DETR and MaskFormer.
  311. """
  312. )
  313. # Copied from transformers.models.convnext.modeling_convnext.ConvNextBackbone with CONVNEXT->CONVNEXTV2,ConvNext->ConvNextV2,facebook/convnext-tiny-224->facebook/convnextv2-tiny-1k-224
  314. class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
  315. has_attentions = False
  316. def __init__(self, config):
  317. super().__init__(config)
  318. super()._init_backbone(config)
  319. self.embeddings = ConvNextV2Embeddings(config)
  320. self.encoder = ConvNextV2Encoder(config)
  321. self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
  322. # Add layer norms to hidden states of out_features
  323. hidden_states_norms = {}
  324. for stage, num_channels in zip(self._out_features, self.channels):
  325. hidden_states_norms[stage] = ConvNextV2LayerNorm(num_channels, data_format="channels_first")
  326. self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
  327. # initialize weights and apply final processing
  328. self.post_init()
  329. @can_return_tuple
  330. @auto_docstring
  331. def forward(
  332. self,
  333. pixel_values: torch.Tensor,
  334. output_hidden_states: Optional[bool] = None,
  335. ) -> BackboneOutput:
  336. r"""
  337. Examples:
  338. ```python
  339. >>> from transformers import AutoImageProcessor, AutoBackbone
  340. >>> import torch
  341. >>> from PIL import Image
  342. >>> import requests
  343. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  344. >>> image = Image.open(requests.get(url, stream=True).raw)
  345. >>> processor = AutoImageProcessor.from_pretrained("facebook/convnextv2-tiny-1k-224")
  346. >>> model = AutoBackbone.from_pretrained("facebook/convnextv2-tiny-1k-224")
  347. >>> inputs = processor(image, return_tensors="pt")
  348. >>> outputs = model(**inputs)
  349. ```"""
  350. if output_hidden_states is None:
  351. output_hidden_states = self.config.output_hidden_states
  352. embedding_output = self.embeddings(pixel_values)
  353. outputs: BaseModelOutputWithPoolingAndNoAttention = self.encoder(embedding_output, output_hidden_states=True)
  354. hidden_states = outputs.hidden_states
  355. feature_maps = []
  356. for stage, hidden_state in zip(self.stage_names, hidden_states):
  357. if stage in self.out_features:
  358. hidden_state = self.hidden_states_norms[stage](hidden_state)
  359. feature_maps.append(hidden_state)
  360. return BackboneOutput(
  361. feature_maps=tuple(feature_maps),
  362. hidden_states=hidden_states if output_hidden_states else None,
  363. )
# Public API of this module: the names exported by `from <module> import *`.
__all__ = ["ConvNextV2ForImageClassification", "ConvNextV2Model", "ConvNextV2PreTrainedModel", "ConvNextV2Backbone"]