# modeling_dinat.py (33 KB)
  1. # coding=utf-8
  2. # Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Dilated Neighborhood Attention Transformer model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...modeling_outputs import BackboneOutput
  23. from ...modeling_utils import PreTrainedModel
  24. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  25. from ...utils import (
  26. ModelOutput,
  27. OptionalDependencyNotAvailable,
  28. auto_docstring,
  29. is_natten_available,
  30. logging,
  31. requires_backends,
  32. )
  33. from ...utils.backbone_utils import BackboneMixin
  34. from .configuration_dinat import DinatConfig
  35. if is_natten_available():
  36. from natten.functional import natten2dav, natten2dqkrpb
  37. else:
  38. def natten2dqkrpb(*args, **kwargs):
  39. raise OptionalDependencyNotAvailable()
  40. def natten2dav(*args, **kwargs):
  41. raise OptionalDependencyNotAvailable()
  42. logger = logging.get_logger(__name__)
  43. # drop_path and DinatDropPath are from the timm library.
@dataclass
@auto_docstring(
    custom_intro="""
    Dinat encoder's outputs, with potential hidden states and attentions.
    """
)
class DinatEncoderOutput(ModelOutput):
    r"""
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Output of the last encoder stage, as returned by `DinatEncoder.forward`.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Per-stage hidden states in the encoder's native (channels-last) layout; None unless requested.
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    # Attention probabilities collected from the stages; None unless requested.
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    # Same entries as `hidden_states`, permuted to channels-first (b c h w).
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Dinat model's outputs that also contains a pooling of the last hidden states.
    """
)
class DinatModelOutput(ModelOutput):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Encoder output after the model's final LayerNorm.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Spatially averaged `last_hidden_state`; None when the model was built without a pooler.
    pooler_output: Optional[torch.FloatTensor] = None
    # Forwarded unchanged from the encoder output.
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    # Forwarded unchanged from the encoder output.
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    # Forwarded unchanged from the encoder output (channels-first hidden states).
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Dinat outputs for image classification.
    """
)
class DinatImageClassifierOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Only populated when labels are passed to the classification model.
    loss: Optional[torch.FloatTensor] = None
    # Output of the classifier head applied to the pooled features.
    logits: Optional[torch.FloatTensor] = None
    # Forwarded unchanged from the base model output.
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    # Forwarded unchanged from the base model output.
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    # Forwarded unchanged from the base model output (channels-first hidden states).
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  106. class DinatEmbeddings(nn.Module):
  107. """
  108. Construct the patch and position embeddings.
  109. """
  110. def __init__(self, config):
  111. super().__init__()
  112. self.patch_embeddings = DinatPatchEmbeddings(config)
  113. self.norm = nn.LayerNorm(config.embed_dim)
  114. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  115. def forward(self, pixel_values: Optional[torch.FloatTensor]) -> tuple[torch.Tensor]:
  116. embeddings = self.patch_embeddings(pixel_values)
  117. embeddings = self.norm(embeddings)
  118. embeddings = self.dropout(embeddings)
  119. return embeddings
  120. class DinatPatchEmbeddings(nn.Module):
  121. """
  122. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  123. `hidden_states` (patch embeddings) of shape `(batch_size, height, width, hidden_size)` to be consumed by a
  124. Transformer.
  125. """
  126. def __init__(self, config):
  127. super().__init__()
  128. patch_size = config.patch_size
  129. num_channels, hidden_size = config.num_channels, config.embed_dim
  130. self.num_channels = num_channels
  131. if patch_size == 4:
  132. pass
  133. else:
  134. # TODO: Support arbitrary patch sizes.
  135. raise ValueError("Dinat only supports patch size of 4 at the moment.")
  136. self.projection = nn.Sequential(
  137. nn.Conv2d(self.num_channels, hidden_size // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
  138. nn.Conv2d(hidden_size // 2, hidden_size, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
  139. )
  140. def forward(self, pixel_values: Optional[torch.FloatTensor]) -> torch.Tensor:
  141. _, num_channels, height, width = pixel_values.shape
  142. if num_channels != self.num_channels:
  143. raise ValueError(
  144. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  145. )
  146. embeddings = self.projection(pixel_values)
  147. embeddings = embeddings.permute(0, 2, 3, 1)
  148. return embeddings
  149. class DinatDownsampler(nn.Module):
  150. """
  151. Convolutional Downsampling Layer.
  152. Args:
  153. dim (`int`):
  154. Number of input channels.
  155. norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
  156. Normalization layer class.
  157. """
  158. def __init__(self, dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
  159. super().__init__()
  160. self.dim = dim
  161. self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  162. self.norm = norm_layer(2 * dim)
  163. def forward(self, input_feature: torch.Tensor) -> torch.Tensor:
  164. input_feature = self.reduction(input_feature.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
  165. input_feature = self.norm(input_feature)
  166. return input_feature
  167. # Copied from transformers.models.beit.modeling_beit.drop_path
  168. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  169. """
  170. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  171. Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
  172. however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  173. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
  174. layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
  175. argument.
  176. """
  177. if drop_prob == 0.0 or not training:
  178. return input
  179. keep_prob = 1 - drop_prob
  180. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  181. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  182. random_tensor.floor_() # binarize
  183. output = input.div(keep_prob) * random_tensor
  184. return output
  185. # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Dinat
  186. class DinatDropPath(nn.Module):
  187. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  188. def __init__(self, drop_prob: Optional[float] = None) -> None:
  189. super().__init__()
  190. self.drop_prob = drop_prob
  191. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  192. return drop_path(hidden_states, self.drop_prob, self.training)
  193. def extra_repr(self) -> str:
  194. return f"p={self.drop_prob}"
  195. class NeighborhoodAttention(nn.Module):
  196. def __init__(self, config, dim, num_heads, kernel_size, dilation):
  197. super().__init__()
  198. if dim % num_heads != 0:
  199. raise ValueError(
  200. f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
  201. )
  202. self.num_attention_heads = num_heads
  203. self.attention_head_size = int(dim / num_heads)
  204. self.all_head_size = self.num_attention_heads * self.attention_head_size
  205. self.kernel_size = kernel_size
  206. self.dilation = dilation
  207. # rpb is learnable relative positional biases; same concept is used Swin.
  208. self.rpb = nn.Parameter(torch.zeros(num_heads, (2 * self.kernel_size - 1), (2 * self.kernel_size - 1)))
  209. self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  210. self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  211. self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  212. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  213. def forward(
  214. self,
  215. hidden_states: torch.Tensor,
  216. output_attentions: Optional[bool] = False,
  217. ) -> tuple[torch.Tensor]:
  218. batch_size, seq_length, _ = hidden_states.shape
  219. query_layer = (
  220. self.query(hidden_states)
  221. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  222. .transpose(1, 2)
  223. )
  224. key_layer = (
  225. self.key(hidden_states)
  226. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  227. .transpose(1, 2)
  228. )
  229. value_layer = (
  230. self.value(hidden_states)
  231. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  232. .transpose(1, 2)
  233. )
  234. # Apply the scale factor before computing attention weights. It's usually more efficient because
  235. # attention weights are typically a bigger tensor compared to query.
  236. # It gives identical results because scalars are commutable in matrix multiplication.
  237. query_layer = query_layer / math.sqrt(self.attention_head_size)
  238. # Compute NA between "query" and "key" to get the raw attention scores, and add relative positional biases.
  239. attention_scores = natten2dqkrpb(query_layer, key_layer, self.rpb, self.kernel_size, self.dilation)
  240. # Normalize the attention scores to probabilities.
  241. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  242. # This is actually dropping out entire tokens to attend to, which might
  243. # seem a bit unusual, but is taken from the original Transformer paper.
  244. attention_probs = self.dropout(attention_probs)
  245. context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, self.dilation)
  246. context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
  247. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  248. context_layer = context_layer.view(new_context_layer_shape)
  249. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  250. return outputs
  251. class NeighborhoodAttentionOutput(nn.Module):
  252. def __init__(self, config, dim):
  253. super().__init__()
  254. self.dense = nn.Linear(dim, dim)
  255. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  256. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  257. hidden_states = self.dense(hidden_states)
  258. hidden_states = self.dropout(hidden_states)
  259. return hidden_states
  260. class NeighborhoodAttentionModule(nn.Module):
  261. def __init__(self, config, dim, num_heads, kernel_size, dilation):
  262. super().__init__()
  263. self.self = NeighborhoodAttention(config, dim, num_heads, kernel_size, dilation)
  264. self.output = NeighborhoodAttentionOutput(config, dim)
  265. self.pruned_heads = set()
  266. def prune_heads(self, heads):
  267. if len(heads) == 0:
  268. return
  269. heads, index = find_pruneable_heads_and_indices(
  270. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  271. )
  272. # Prune linear layers
  273. self.self.query = prune_linear_layer(self.self.query, index)
  274. self.self.key = prune_linear_layer(self.self.key, index)
  275. self.self.value = prune_linear_layer(self.self.value, index)
  276. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  277. # Update hyper params and store pruned heads
  278. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  279. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  280. self.pruned_heads = self.pruned_heads.union(heads)
  281. def forward(
  282. self,
  283. hidden_states: torch.Tensor,
  284. output_attentions: Optional[bool] = False,
  285. ) -> tuple[torch.Tensor]:
  286. self_outputs = self.self(hidden_states, output_attentions)
  287. attention_output = self.output(self_outputs[0], hidden_states)
  288. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  289. return outputs
  290. class DinatIntermediate(nn.Module):
  291. def __init__(self, config, dim):
  292. super().__init__()
  293. self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
  294. if isinstance(config.hidden_act, str):
  295. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  296. else:
  297. self.intermediate_act_fn = config.hidden_act
  298. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  299. hidden_states = self.dense(hidden_states)
  300. hidden_states = self.intermediate_act_fn(hidden_states)
  301. return hidden_states
  302. class DinatOutput(nn.Module):
  303. def __init__(self, config, dim):
  304. super().__init__()
  305. self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
  306. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  307. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  308. hidden_states = self.dense(hidden_states)
  309. hidden_states = self.dropout(hidden_states)
  310. return hidden_states
  311. class DinatLayer(nn.Module):
  312. def __init__(self, config, dim, num_heads, dilation, drop_path_rate=0.0):
  313. super().__init__()
  314. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  315. self.kernel_size = config.kernel_size
  316. self.dilation = dilation
  317. self.window_size = self.kernel_size * self.dilation
  318. self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  319. self.attention = NeighborhoodAttentionModule(
  320. config, dim, num_heads, kernel_size=self.kernel_size, dilation=self.dilation
  321. )
  322. self.drop_path = DinatDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
  323. self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  324. self.intermediate = DinatIntermediate(config, dim)
  325. self.output = DinatOutput(config, dim)
  326. self.layer_scale_parameters = (
  327. nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
  328. if config.layer_scale_init_value > 0
  329. else None
  330. )
  331. def maybe_pad(self, hidden_states, height, width):
  332. window_size = self.window_size
  333. pad_values = (0, 0, 0, 0, 0, 0)
  334. if height < window_size or width < window_size:
  335. pad_l = pad_t = 0
  336. pad_r = max(0, window_size - width)
  337. pad_b = max(0, window_size - height)
  338. pad_values = (0, 0, pad_l, pad_r, pad_t, pad_b)
  339. hidden_states = nn.functional.pad(hidden_states, pad_values)
  340. return hidden_states, pad_values
  341. def forward(
  342. self,
  343. hidden_states: torch.Tensor,
  344. output_attentions: Optional[bool] = False,
  345. ) -> tuple[torch.Tensor, torch.Tensor]:
  346. batch_size, height, width, channels = hidden_states.size()
  347. shortcut = hidden_states
  348. hidden_states = self.layernorm_before(hidden_states)
  349. # pad hidden_states if they are smaller than kernel size x dilation
  350. hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
  351. _, height_pad, width_pad, _ = hidden_states.shape
  352. attention_outputs = self.attention(hidden_states, output_attentions=output_attentions)
  353. attention_output = attention_outputs[0]
  354. was_padded = pad_values[3] > 0 or pad_values[5] > 0
  355. if was_padded:
  356. attention_output = attention_output[:, :height, :width, :].contiguous()
  357. if self.layer_scale_parameters is not None:
  358. attention_output = self.layer_scale_parameters[0] * attention_output
  359. hidden_states = shortcut + self.drop_path(attention_output)
  360. layer_output = self.layernorm_after(hidden_states)
  361. layer_output = self.output(self.intermediate(layer_output))
  362. if self.layer_scale_parameters is not None:
  363. layer_output = self.layer_scale_parameters[1] * layer_output
  364. layer_output = hidden_states + self.drop_path(layer_output)
  365. layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
  366. return layer_outputs
  367. class DinatStage(nn.Module):
  368. def __init__(self, config, dim, depth, num_heads, dilations, drop_path_rate, downsample):
  369. super().__init__()
  370. self.config = config
  371. self.dim = dim
  372. self.layers = nn.ModuleList(
  373. [
  374. DinatLayer(
  375. config=config,
  376. dim=dim,
  377. num_heads=num_heads,
  378. dilation=dilations[i],
  379. drop_path_rate=drop_path_rate[i],
  380. )
  381. for i in range(depth)
  382. ]
  383. )
  384. # patch merging layer
  385. if downsample is not None:
  386. self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm)
  387. else:
  388. self.downsample = None
  389. self.pointing = False
  390. def forward(
  391. self,
  392. hidden_states: torch.Tensor,
  393. output_attentions: Optional[bool] = False,
  394. ) -> tuple[torch.Tensor]:
  395. _, height, width, _ = hidden_states.size()
  396. for i, layer_module in enumerate(self.layers):
  397. layer_outputs = layer_module(hidden_states, output_attentions)
  398. hidden_states = layer_outputs[0]
  399. hidden_states_before_downsampling = hidden_states
  400. if self.downsample is not None:
  401. hidden_states = self.downsample(hidden_states_before_downsampling)
  402. stage_outputs = (hidden_states, hidden_states_before_downsampling)
  403. if output_attentions:
  404. stage_outputs += layer_outputs[1:]
  405. return stage_outputs
  406. class DinatEncoder(nn.Module):
  407. def __init__(self, config):
  408. super().__init__()
  409. self.num_levels = len(config.depths)
  410. self.config = config
  411. dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
  412. self.levels = nn.ModuleList(
  413. [
  414. DinatStage(
  415. config=config,
  416. dim=int(config.embed_dim * 2**i_layer),
  417. depth=config.depths[i_layer],
  418. num_heads=config.num_heads[i_layer],
  419. dilations=config.dilations[i_layer],
  420. drop_path_rate=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
  421. downsample=DinatDownsampler if (i_layer < self.num_levels - 1) else None,
  422. )
  423. for i_layer in range(self.num_levels)
  424. ]
  425. )
  426. def forward(
  427. self,
  428. hidden_states: torch.Tensor,
  429. output_attentions: Optional[bool] = False,
  430. output_hidden_states: Optional[bool] = False,
  431. output_hidden_states_before_downsampling: Optional[bool] = False,
  432. return_dict: Optional[bool] = True,
  433. ) -> Union[tuple, DinatEncoderOutput]:
  434. all_hidden_states = () if output_hidden_states else None
  435. all_reshaped_hidden_states = () if output_hidden_states else None
  436. all_self_attentions = () if output_attentions else None
  437. if output_hidden_states:
  438. # rearrange b h w c -> b c h w
  439. reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
  440. all_hidden_states += (hidden_states,)
  441. all_reshaped_hidden_states += (reshaped_hidden_state,)
  442. for i, layer_module in enumerate(self.levels):
  443. layer_outputs = layer_module(hidden_states, output_attentions)
  444. hidden_states = layer_outputs[0]
  445. hidden_states_before_downsampling = layer_outputs[1]
  446. if output_hidden_states and output_hidden_states_before_downsampling:
  447. # rearrange b h w c -> b c h w
  448. reshaped_hidden_state = hidden_states_before_downsampling.permute(0, 3, 1, 2)
  449. all_hidden_states += (hidden_states_before_downsampling,)
  450. all_reshaped_hidden_states += (reshaped_hidden_state,)
  451. elif output_hidden_states and not output_hidden_states_before_downsampling:
  452. # rearrange b h w c -> b c h w
  453. reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
  454. all_hidden_states += (hidden_states,)
  455. all_reshaped_hidden_states += (reshaped_hidden_state,)
  456. if output_attentions:
  457. all_self_attentions += layer_outputs[2:]
  458. if not return_dict:
  459. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  460. return DinatEncoderOutput(
  461. last_hidden_state=hidden_states,
  462. hidden_states=all_hidden_states,
  463. attentions=all_self_attentions,
  464. reshaped_hidden_states=all_reshaped_hidden_states,
  465. )
  466. @auto_docstring
  467. class DinatPreTrainedModel(PreTrainedModel):
  468. config: DinatConfig
  469. base_model_prefix = "dinat"
  470. main_input_name = "pixel_values"
  471. def _init_weights(self, module):
  472. """Initialize the weights"""
  473. if isinstance(module, (nn.Linear, nn.Conv2d)):
  474. # Slightly different from the TF version which uses truncated_normal for initialization
  475. # cf https://github.com/pytorch/pytorch/pull/5617
  476. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  477. if module.bias is not None:
  478. module.bias.data.zero_()
  479. elif isinstance(module, nn.LayerNorm):
  480. module.bias.data.zero_()
  481. module.weight.data.fill_(1.0)
  482. @auto_docstring
  483. class DinatModel(DinatPreTrainedModel):
  484. def __init__(self, config, add_pooling_layer=True):
  485. r"""
  486. add_pooling_layer (bool, *optional*, defaults to `True`):
  487. Whether to add a pooling layer
  488. """
  489. super().__init__(config)
  490. requires_backends(self, ["natten"])
  491. self.config = config
  492. self.num_levels = len(config.depths)
  493. self.num_features = int(config.embed_dim * 2 ** (self.num_levels - 1))
  494. self.embeddings = DinatEmbeddings(config)
  495. self.encoder = DinatEncoder(config)
  496. self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
  497. self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
  498. # Initialize weights and apply final processing
  499. self.post_init()
  500. def get_input_embeddings(self):
  501. return self.embeddings.patch_embeddings
  502. def _prune_heads(self, heads_to_prune):
  503. """
  504. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  505. class PreTrainedModel
  506. """
  507. for layer, heads in heads_to_prune.items():
  508. self.encoder.layer[layer].attention.prune_heads(heads)
  509. @auto_docstring
  510. def forward(
  511. self,
  512. pixel_values: Optional[torch.FloatTensor] = None,
  513. output_attentions: Optional[bool] = None,
  514. output_hidden_states: Optional[bool] = None,
  515. return_dict: Optional[bool] = None,
  516. ) -> Union[tuple, DinatModelOutput]:
  517. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  518. output_hidden_states = (
  519. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  520. )
  521. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  522. if pixel_values is None:
  523. raise ValueError("You have to specify pixel_values")
  524. embedding_output = self.embeddings(pixel_values)
  525. encoder_outputs = self.encoder(
  526. embedding_output,
  527. output_attentions=output_attentions,
  528. output_hidden_states=output_hidden_states,
  529. return_dict=return_dict,
  530. )
  531. sequence_output = encoder_outputs[0]
  532. sequence_output = self.layernorm(sequence_output)
  533. pooled_output = None
  534. if self.pooler is not None:
  535. pooled_output = self.pooler(sequence_output.flatten(1, 2).transpose(1, 2))
  536. pooled_output = torch.flatten(pooled_output, 1)
  537. if not return_dict:
  538. output = (sequence_output, pooled_output) + encoder_outputs[1:]
  539. return output
  540. return DinatModelOutput(
  541. last_hidden_state=sequence_output,
  542. pooler_output=pooled_output,
  543. hidden_states=encoder_outputs.hidden_states,
  544. attentions=encoder_outputs.attentions,
  545. reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
  546. )
  547. @auto_docstring(
  548. custom_intro="""
  549. Dinat Model transformer with an image classification head on top (a linear layer on top of the final hidden state
  550. of the [CLS] token) e.g. for ImageNet.
  551. """
  552. )
  553. class DinatForImageClassification(DinatPreTrainedModel):
  554. def __init__(self, config):
  555. super().__init__(config)
  556. requires_backends(self, ["natten"])
  557. self.num_labels = config.num_labels
  558. self.dinat = DinatModel(config)
  559. # Classifier head
  560. self.classifier = (
  561. nn.Linear(self.dinat.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
  562. )
  563. # Initialize weights and apply final processing
  564. self.post_init()
  565. @auto_docstring
  566. def forward(
  567. self,
  568. pixel_values: Optional[torch.FloatTensor] = None,
  569. labels: Optional[torch.LongTensor] = None,
  570. output_attentions: Optional[bool] = None,
  571. output_hidden_states: Optional[bool] = None,
  572. return_dict: Optional[bool] = None,
  573. ) -> Union[tuple, DinatImageClassifierOutput]:
  574. r"""
  575. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  576. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  577. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  578. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  579. """
  580. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  581. outputs = self.dinat(
  582. pixel_values,
  583. output_attentions=output_attentions,
  584. output_hidden_states=output_hidden_states,
  585. return_dict=return_dict,
  586. )
  587. pooled_output = outputs[1]
  588. logits = self.classifier(pooled_output)
  589. loss = None
  590. if labels is not None:
  591. loss = self.loss_function(labels, logits, self.config)
  592. if not return_dict:
  593. output = (logits,) + outputs[2:]
  594. return ((loss,) + output) if loss is not None else output
  595. return DinatImageClassifierOutput(
  596. loss=loss,
  597. logits=logits,
  598. hidden_states=outputs.hidden_states,
  599. attentions=outputs.attentions,
  600. reshaped_hidden_states=outputs.reshaped_hidden_states,
  601. )
  602. @auto_docstring(
  603. custom_intro="""
  604. NAT backbone, to be used with frameworks like DETR and MaskFormer.
  605. """
  606. )
  607. class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
  608. def __init__(self, config):
  609. super().__init__(config)
  610. super()._init_backbone(config)
  611. requires_backends(self, ["natten"])
  612. self.embeddings = DinatEmbeddings(config)
  613. self.encoder = DinatEncoder(config)
  614. self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
  615. # Add layer norms to hidden states of out_features
  616. hidden_states_norms = {}
  617. for stage, num_channels in zip(self._out_features, self.channels):
  618. hidden_states_norms[stage] = nn.LayerNorm(num_channels)
  619. self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
  620. # Initialize weights and apply final processing
  621. self.post_init()
  622. def get_input_embeddings(self):
  623. return self.embeddings.patch_embeddings
  624. @auto_docstring
  625. def forward(
  626. self,
  627. pixel_values: torch.Tensor,
  628. output_hidden_states: Optional[bool] = None,
  629. output_attentions: Optional[bool] = None,
  630. return_dict: Optional[bool] = None,
  631. ) -> BackboneOutput:
  632. r"""
  633. Examples:
  634. ```python
  635. >>> from transformers import AutoImageProcessor, AutoBackbone
  636. >>> import torch
  637. >>> from PIL import Image
  638. >>> import requests
  639. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  640. >>> image = Image.open(requests.get(url, stream=True).raw)
  641. >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
  642. >>> model = AutoBackbone.from_pretrained(
  643. ... "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
  644. ... )
  645. >>> inputs = processor(image, return_tensors="pt")
  646. >>> outputs = model(**inputs)
  647. >>> feature_maps = outputs.feature_maps
  648. >>> list(feature_maps[-1].shape)
  649. [1, 512, 7, 7]
  650. ```"""
  651. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  652. output_hidden_states = (
  653. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  654. )
  655. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  656. embedding_output = self.embeddings(pixel_values)
  657. outputs = self.encoder(
  658. embedding_output,
  659. output_attentions=output_attentions,
  660. output_hidden_states=True,
  661. output_hidden_states_before_downsampling=True,
  662. return_dict=True,
  663. )
  664. hidden_states = outputs.reshaped_hidden_states
  665. feature_maps = ()
  666. for stage, hidden_state in zip(self.stage_names, hidden_states):
  667. if stage in self.out_features:
  668. batch_size, num_channels, height, width = hidden_state.shape
  669. hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
  670. hidden_state = hidden_state.view(batch_size, height * width, num_channels)
  671. hidden_state = self.hidden_states_norms[stage](hidden_state)
  672. hidden_state = hidden_state.view(batch_size, height, width, num_channels)
  673. hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
  674. feature_maps += (hidden_state,)
  675. if not return_dict:
  676. output = (feature_maps,)
  677. if output_hidden_states:
  678. output += (outputs.hidden_states,)
  679. return output
  680. return BackboneOutput(
  681. feature_maps=feature_maps,
  682. hidden_states=outputs.hidden_states if output_hidden_states else None,
  683. attentions=outputs.attentions,
  684. )
  685. __all__ = ["DinatForImageClassification", "DinatModel", "DinatPreTrainedModel", "DinatBackbone"]