modeling_deit.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791
  1. # coding=utf-8
  2. # Copyright 2021 Facebook AI Research (FAIR), Ross Wightman, The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch DeiT model."""
  16. import collections.abc
  17. from dataclasses import dataclass
  18. from typing import Callable, Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. BaseModelOutputWithPooling,
  26. ImageClassifierOutput,
  27. MaskedImageModelingOutput,
  28. )
  29. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  32. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging, torch_int
  33. from ...utils.generic import can_return_tuple, check_model_inputs
  34. from .configuration_deit import DeiTConfig
  35. logger = logging.get_logger(__name__)
  36. class DeiTEmbeddings(nn.Module):
  37. """
  38. Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.
  39. """
  40. def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:
  41. super().__init__()
  42. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  43. self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  44. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
  45. self.patch_embeddings = DeiTPatchEmbeddings(config)
  46. num_patches = self.patch_embeddings.num_patches
  47. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
  48. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  49. self.patch_size = config.patch_size
  50. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  51. """
  52. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  53. images. This method is also adapted to support torch.jit tracing and 2 class embeddings.
  54. Adapted from:
  55. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  56. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  57. """
  58. num_patches = embeddings.shape[1] - 2
  59. num_positions = self.position_embeddings.shape[1] - 2
  60. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  61. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  62. return self.position_embeddings
  63. class_and_dist_pos_embed = self.position_embeddings[:, :2]
  64. patch_pos_embed = self.position_embeddings[:, 2:]
  65. dim = embeddings.shape[-1]
  66. new_height = height // self.patch_size
  67. new_width = width // self.patch_size
  68. sqrt_num_positions = torch_int(num_positions**0.5)
  69. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  70. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  71. patch_pos_embed = nn.functional.interpolate(
  72. patch_pos_embed,
  73. size=(new_height, new_width),
  74. mode="bicubic",
  75. align_corners=False,
  76. )
  77. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  78. return torch.cat((class_and_dist_pos_embed, patch_pos_embed), dim=1)
  79. def forward(
  80. self,
  81. pixel_values: torch.Tensor,
  82. bool_masked_pos: Optional[torch.BoolTensor] = None,
  83. interpolate_pos_encoding: bool = False,
  84. ) -> torch.Tensor:
  85. _, _, height, width = pixel_values.shape
  86. embeddings = self.patch_embeddings(pixel_values)
  87. batch_size, seq_length, _ = embeddings.size()
  88. if bool_masked_pos is not None:
  89. mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
  90. # replace the masked visual tokens by mask_tokens
  91. mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  92. embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  93. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  94. distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
  95. embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
  96. position_embedding = self.position_embeddings
  97. if interpolate_pos_encoding:
  98. position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
  99. embeddings = embeddings + position_embedding
  100. embeddings = self.dropout(embeddings)
  101. return embeddings
  102. class DeiTPatchEmbeddings(nn.Module):
  103. """
  104. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  105. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  106. Transformer.
  107. """
  108. def __init__(self, config):
  109. super().__init__()
  110. image_size, patch_size = config.image_size, config.patch_size
  111. num_channels, hidden_size = config.num_channels, config.hidden_size
  112. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  113. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  114. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  115. self.image_size = image_size
  116. self.patch_size = patch_size
  117. self.num_channels = num_channels
  118. self.num_patches = num_patches
  119. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  120. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  121. batch_size, num_channels, height, width = pixel_values.shape
  122. if num_channels != self.num_channels:
  123. raise ValueError(
  124. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  125. )
  126. x = self.projection(pixel_values).flatten(2).transpose(1, 2)
  127. return x
  128. # Copied from transformers.models.vit.modeling_vit.eager_attention_forward
  129. def eager_attention_forward(
  130. module: nn.Module,
  131. query: torch.Tensor,
  132. key: torch.Tensor,
  133. value: torch.Tensor,
  134. attention_mask: Optional[torch.Tensor],
  135. scaling: float,
  136. dropout: float = 0.0,
  137. **kwargs,
  138. ):
  139. # Take the dot product between "query" and "key" to get the raw attention scores.
  140. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  141. # Normalize the attention scores to probabilities.
  142. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  143. # This is actually dropping out entire tokens to attend to, which might
  144. # seem a bit unusual, but is taken from the original Transformer paper.
  145. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  146. # Mask heads if we want to
  147. if attention_mask is not None:
  148. attn_weights = attn_weights * attention_mask
  149. attn_output = torch.matmul(attn_weights, value)
  150. attn_output = attn_output.transpose(1, 2).contiguous()
  151. return attn_output, attn_weights
  152. # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DeiT
  153. class DeiTSelfAttention(nn.Module):
  154. def __init__(self, config: DeiTConfig):
  155. super().__init__()
  156. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  157. raise ValueError(
  158. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  159. f"heads {config.num_attention_heads}."
  160. )
  161. self.config = config
  162. self.num_attention_heads = config.num_attention_heads
  163. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  164. self.all_head_size = self.num_attention_heads * self.attention_head_size
  165. self.dropout_prob = config.attention_probs_dropout_prob
  166. self.scaling = self.attention_head_size**-0.5
  167. self.is_causal = False
  168. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  169. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  170. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  171. def forward(
  172. self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
  173. ) -> tuple[torch.Tensor, torch.Tensor]:
  174. batch_size = hidden_states.shape[0]
  175. new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
  176. key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
  177. value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
  178. query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
  179. attention_interface: Callable = eager_attention_forward
  180. if self.config._attn_implementation != "eager":
  181. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  182. context_layer, attention_probs = attention_interface(
  183. self,
  184. query_layer,
  185. key_layer,
  186. value_layer,
  187. head_mask,
  188. is_causal=self.is_causal,
  189. scaling=self.scaling,
  190. dropout=0.0 if not self.training else self.dropout_prob,
  191. )
  192. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  193. context_layer = context_layer.reshape(new_context_layer_shape)
  194. return context_layer, attention_probs
  195. # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT
  196. class DeiTSelfOutput(nn.Module):
  197. """
  198. The residual connection is defined in DeiTLayer instead of here (as is the case with other models), due to the
  199. layernorm applied before each block.
  200. """
  201. def __init__(self, config: DeiTConfig):
  202. super().__init__()
  203. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  204. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  205. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  206. hidden_states = self.dense(hidden_states)
  207. hidden_states = self.dropout(hidden_states)
  208. return hidden_states
  209. # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->DeiT
  210. class DeiTAttention(nn.Module):
  211. def __init__(self, config: DeiTConfig):
  212. super().__init__()
  213. self.attention = DeiTSelfAttention(config)
  214. self.output = DeiTSelfOutput(config)
  215. self.pruned_heads = set()
  216. def prune_heads(self, heads: set[int]):
  217. if len(heads) == 0:
  218. return
  219. heads, index = find_pruneable_heads_and_indices(
  220. heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
  221. )
  222. # Prune linear layers
  223. self.attention.query = prune_linear_layer(self.attention.query, index)
  224. self.attention.key = prune_linear_layer(self.attention.key, index)
  225. self.attention.value = prune_linear_layer(self.attention.value, index)
  226. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  227. # Update hyper params and store pruned heads
  228. self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
  229. self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
  230. self.pruned_heads = self.pruned_heads.union(heads)
  231. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  232. self_attn_output, _ = self.attention(hidden_states, head_mask)
  233. output = self.output(self_attn_output, hidden_states)
  234. return output
  235. # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
  236. class DeiTIntermediate(nn.Module):
  237. def __init__(self, config: DeiTConfig):
  238. super().__init__()
  239. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  240. if isinstance(config.hidden_act, str):
  241. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  242. else:
  243. self.intermediate_act_fn = config.hidden_act
  244. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  245. hidden_states = self.dense(hidden_states)
  246. hidden_states = self.intermediate_act_fn(hidden_states)
  247. return hidden_states
  248. # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->DeiT
  249. class DeiTOutput(nn.Module):
  250. def __init__(self, config: DeiTConfig):
  251. super().__init__()
  252. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  253. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  254. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  255. hidden_states = self.dense(hidden_states)
  256. hidden_states = self.dropout(hidden_states)
  257. hidden_states = hidden_states + input_tensor
  258. return hidden_states
  259. # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT
  260. class DeiTLayer(GradientCheckpointingLayer):
  261. """This corresponds to the Block class in the timm implementation."""
  262. def __init__(self, config: DeiTConfig):
  263. super().__init__()
  264. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  265. self.seq_len_dim = 1
  266. self.attention = DeiTAttention(config)
  267. self.intermediate = DeiTIntermediate(config)
  268. self.output = DeiTOutput(config)
  269. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  270. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  271. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  272. hidden_states_norm = self.layernorm_before(hidden_states)
  273. attention_output = self.attention(hidden_states_norm, head_mask)
  274. # first residual connection
  275. hidden_states = attention_output + hidden_states
  276. # in DeiT, layernorm is also applied after self-attention
  277. layer_output = self.layernorm_after(hidden_states)
  278. layer_output = self.intermediate(layer_output)
  279. # second residual connection is done here
  280. layer_output = self.output(layer_output, hidden_states)
  281. return layer_output
  282. # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->DeiT
  283. class DeiTEncoder(nn.Module):
  284. def __init__(self, config: DeiTConfig):
  285. super().__init__()
  286. self.config = config
  287. self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)])
  288. self.gradient_checkpointing = False
  289. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> BaseModelOutput:
  290. for i, layer_module in enumerate(self.layer):
  291. layer_head_mask = head_mask[i] if head_mask is not None else None
  292. hidden_states = layer_module(hidden_states, layer_head_mask)
  293. return BaseModelOutput(last_hidden_state=hidden_states)
  294. @auto_docstring
  295. class DeiTPreTrainedModel(PreTrainedModel):
  296. config: DeiTConfig
  297. base_model_prefix = "deit"
  298. main_input_name = "pixel_values"
  299. supports_gradient_checkpointing = True
  300. _no_split_modules = ["DeiTLayer"]
  301. _supports_sdpa = True
  302. _supports_flash_attn = True
  303. _supports_flex_attn = True
  304. _supports_attention_backend = True
  305. _can_record_outputs = {
  306. "hidden_states": DeiTLayer,
  307. "attentions": DeiTSelfAttention,
  308. }
  309. def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
  310. """Initialize the weights"""
  311. if isinstance(module, (nn.Linear, nn.Conv2d)):
  312. # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
  313. # `trunc_normal_cpu` not implemented in `half` issues
  314. module.weight.data = nn.init.trunc_normal_(
  315. module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
  316. ).to(module.weight.dtype)
  317. if module.bias is not None:
  318. module.bias.data.zero_()
  319. elif isinstance(module, nn.LayerNorm):
  320. module.bias.data.zero_()
  321. module.weight.data.fill_(1.0)
  322. elif isinstance(module, DeiTEmbeddings):
  323. module.cls_token.data.zero_()
  324. module.position_embeddings.data.zero_()
  325. module.distillation_token.data.zero_()
  326. if module.mask_token is not None:
  327. module.mask_token.data.zero_()
  328. @auto_docstring
  329. class DeiTModel(DeiTPreTrainedModel):
  330. def __init__(self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None:
  331. r"""
  332. add_pooling_layer (bool, *optional*, defaults to `True`):
  333. Whether to add a pooling layer
  334. use_mask_token (`bool`, *optional*, defaults to `False`):
  335. Whether to use a mask token for masked image modeling.
  336. """
  337. super().__init__(config)
  338. self.config = config
  339. self.embeddings = DeiTEmbeddings(config, use_mask_token=use_mask_token)
  340. self.encoder = DeiTEncoder(config)
  341. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  342. self.pooler = DeiTPooler(config) if add_pooling_layer else None
  343. # Initialize weights and apply final processing
  344. self.post_init()
  345. def get_input_embeddings(self) -> DeiTPatchEmbeddings:
  346. return self.embeddings.patch_embeddings
  347. def _prune_heads(self, heads_to_prune):
  348. """
  349. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  350. class PreTrainedModel
  351. """
  352. for layer, heads in heads_to_prune.items():
  353. self.encoder.layer[layer].attention.prune_heads(heads)
  354. @check_model_inputs(tie_last_hidden_states=False)
  355. @auto_docstring
  356. def forward(
  357. self,
  358. pixel_values: Optional[torch.Tensor] = None,
  359. bool_masked_pos: Optional[torch.BoolTensor] = None,
  360. head_mask: Optional[torch.Tensor] = None,
  361. interpolate_pos_encoding: bool = False,
  362. **kwargs: Unpack[TransformersKwargs],
  363. ) -> BaseModelOutputWithPooling:
  364. r"""
  365. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  366. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  367. """
  368. if pixel_values is None:
  369. raise ValueError("You have to specify pixel_values")
  370. # Prepare head mask if needed
  371. # 1.0 in head_mask indicate we keep the head
  372. # attention_probs has shape bsz x n_heads x N x N
  373. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  374. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  375. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  376. # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
  377. expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
  378. if pixel_values.dtype != expected_dtype:
  379. pixel_values = pixel_values.to(expected_dtype)
  380. embedding_output = self.embeddings(
  381. pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
  382. )
  383. encoder_outputs: BaseModelOutput = self.encoder(embedding_output, head_mask=head_mask)
  384. sequence_output = encoder_outputs.last_hidden_state
  385. sequence_output = self.layernorm(sequence_output)
  386. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  387. return BaseModelOutputWithPooling(
  388. last_hidden_state=sequence_output,
  389. pooler_output=pooled_output,
  390. )
  391. # Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->DeiT
  392. class DeiTPooler(nn.Module):
  393. def __init__(self, config: DeiTConfig):
  394. super().__init__()
  395. self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
  396. self.activation = ACT2FN[config.pooler_act]
  397. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  398. # We "pool" the model by simply taking the hidden state corresponding
  399. # to the first token.
  400. first_token_tensor = hidden_states[:, 0]
  401. pooled_output = self.dense(first_token_tensor)
  402. pooled_output = self.activation(pooled_output)
  403. return pooled_output
  404. @auto_docstring(
  405. custom_intro="""
  406. DeiT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://huggingface.co/papers/2111.09886).
  407. <Tip>
  408. Note that we provide a script to pre-train this model on custom data in our [examples
  409. directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
  410. </Tip>
  411. """
  412. )
  413. class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
  414. def __init__(self, config: DeiTConfig) -> None:
  415. super().__init__(config)
  416. self.deit = DeiTModel(config, add_pooling_layer=False, use_mask_token=True)
  417. self.decoder = nn.Sequential(
  418. nn.Conv2d(
  419. in_channels=config.hidden_size,
  420. out_channels=config.encoder_stride**2 * config.num_channels,
  421. kernel_size=1,
  422. ),
  423. nn.PixelShuffle(config.encoder_stride),
  424. )
  425. # Initialize weights and apply final processing
  426. self.post_init()
  427. @can_return_tuple
  428. @auto_docstring
  429. def forward(
  430. self,
  431. pixel_values: Optional[torch.Tensor] = None,
  432. bool_masked_pos: Optional[torch.BoolTensor] = None,
  433. head_mask: Optional[torch.Tensor] = None,
  434. interpolate_pos_encoding: bool = False,
  435. **kwargs: Unpack[TransformersKwargs],
  436. ) -> MaskedImageModelingOutput:
  437. r"""
  438. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
  439. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  440. Examples:
  441. ```python
  442. >>> from transformers import AutoImageProcessor, DeiTForMaskedImageModeling
  443. >>> import torch
  444. >>> from PIL import Image
  445. >>> import requests
  446. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  447. >>> image = Image.open(requests.get(url, stream=True).raw)
  448. >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
  449. >>> model = DeiTForMaskedImageModeling.from_pretrained("facebook/deit-base-distilled-patch16-224")
  450. >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
  451. >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
  452. >>> # create random boolean mask of shape (batch_size, num_patches)
  453. >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
  454. >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
  455. >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
  456. >>> list(reconstructed_pixel_values.shape)
  457. [1, 3, 224, 224]
  458. ```"""
  459. outputs: BaseModelOutputWithPooling = self.deit(
  460. pixel_values,
  461. bool_masked_pos=bool_masked_pos,
  462. head_mask=head_mask,
  463. interpolate_pos_encoding=interpolate_pos_encoding,
  464. **kwargs,
  465. )
  466. sequence_output = outputs.last_hidden_state
  467. # Reshape to (batch_size, num_channels, height, width)
  468. sequence_output = sequence_output[:, 1:-1]
  469. batch_size, sequence_length, num_channels = sequence_output.shape
  470. height = width = int(sequence_length**0.5)
  471. sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
  472. # Reconstruct pixel values
  473. reconstructed_pixel_values = self.decoder(sequence_output)
  474. masked_im_loss = None
  475. if bool_masked_pos is not None:
  476. size = self.config.image_size // self.config.patch_size
  477. bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
  478. mask = (
  479. bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
  480. .repeat_interleave(self.config.patch_size, 2)
  481. .unsqueeze(1)
  482. .contiguous()
  483. )
  484. reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
  485. masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
  486. return MaskedImageModelingOutput(
  487. loss=masked_im_loss,
  488. reconstruction=reconstructed_pixel_values,
  489. hidden_states=outputs.hidden_states,
  490. attentions=outputs.attentions,
  491. )
  492. @auto_docstring(
  493. custom_intro="""
  494. DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
  495. the [CLS] token) e.g. for ImageNet.
  496. """
  497. )
  498. class DeiTForImageClassification(DeiTPreTrainedModel):
  499. def __init__(self, config: DeiTConfig) -> None:
  500. super().__init__(config)
  501. self.num_labels = config.num_labels
  502. self.deit = DeiTModel(config, add_pooling_layer=False)
  503. # Classifier head
  504. self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  505. # Initialize weights and apply final processing
  506. self.post_init()
  507. @can_return_tuple
  508. @auto_docstring
  509. def forward(
  510. self,
  511. pixel_values: Optional[torch.Tensor] = None,
  512. head_mask: Optional[torch.Tensor] = None,
  513. labels: Optional[torch.Tensor] = None,
  514. interpolate_pos_encoding: bool = False,
  515. **kwargs: Unpack[TransformersKwargs],
  516. ) -> ImageClassifierOutput:
  517. r"""
  518. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  519. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  520. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  521. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  522. Examples:
  523. ```python
  524. >>> from transformers import AutoImageProcessor, DeiTForImageClassification
  525. >>> import torch
  526. >>> from PIL import Image
  527. >>> import requests
  528. >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
  529. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  530. >>> image = Image.open(requests.get(url, stream=True).raw)
  531. >>> # note: we are loading a DeiTForImageClassificationWithTeacher from the hub here,
  532. >>> # so the head will be randomly initialized, hence the predictions will be random
  533. >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
  534. >>> model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")
  535. >>> inputs = image_processor(images=image, return_tensors="pt")
  536. >>> outputs = model(**inputs)
  537. >>> logits = outputs.logits
  538. >>> # model predicts one of the 1000 ImageNet classes
  539. >>> predicted_class_idx = logits.argmax(-1).item()
  540. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  541. Predicted class: Polaroid camera, Polaroid Land camera
  542. ```"""
  543. outputs: BaseModelOutputWithPooling = self.deit(
  544. pixel_values,
  545. head_mask=head_mask,
  546. interpolate_pos_encoding=interpolate_pos_encoding,
  547. **kwargs,
  548. )
  549. sequence_output = outputs.last_hidden_state
  550. logits = self.classifier(sequence_output[:, 0, :])
  551. # we don't use the distillation token
  552. loss = None
  553. if labels is not None:
  554. loss = self.loss_function(labels, logits, self.config, **kwargs)
  555. return ImageClassifierOutput(
  556. loss=loss,
  557. logits=logits,
  558. hidden_states=outputs.hidden_states,
  559. attentions=outputs.attentions,
  560. )
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`DeiTForImageClassificationWithTeacher`].
    """
)
class DeiTForImageClassificationWithTeacherOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Prediction scores as the average of the cls_logits and distillation logits.
    cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
        class token).
    distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
        distillation token).
    """

    # NOTE(review): field order presumably matters if this output is consumed positionally as a tuple —
    # confirm against `ModelOutput` before reordering.
    logits: Optional[torch.FloatTensor] = None
    cls_logits: Optional[torch.FloatTensor] = None
    distillation_logits: Optional[torch.FloatTensor] = None
    # The two fields below are not described in the docstring above; every field defaults to `None`.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
  583. @auto_docstring(
  584. custom_intro="""
  585. DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of
  586. the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.
  587. .. warning::
  588. This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
  589. supported.
  590. """
  591. )
  592. class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
  593. def __init__(self, config: DeiTConfig) -> None:
  594. super().__init__(config)
  595. self.num_labels = config.num_labels
  596. self.deit = DeiTModel(config, add_pooling_layer=False)
  597. # Classifier heads
  598. self.cls_classifier = (
  599. nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  600. )
  601. self.distillation_classifier = (
  602. nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  603. )
  604. # Initialize weights and apply final processing
  605. self.post_init()
  606. @can_return_tuple
  607. @auto_docstring
  608. def forward(
  609. self,
  610. pixel_values: Optional[torch.Tensor] = None,
  611. head_mask: Optional[torch.Tensor] = None,
  612. interpolate_pos_encoding: bool = False,
  613. **kwargs: Unpack[TransformersKwargs],
  614. ) -> DeiTForImageClassificationWithTeacherOutput:
  615. outputs: BaseModelOutputWithPooling = self.deit(
  616. pixel_values,
  617. head_mask=head_mask,
  618. interpolate_pos_encoding=interpolate_pos_encoding,
  619. **kwargs,
  620. )
  621. sequence_output = outputs.last_hidden_state
  622. cls_logits = self.cls_classifier(sequence_output[:, 0, :])
  623. distillation_logits = self.distillation_classifier(sequence_output[:, 1, :])
  624. # during inference, return the average of both classifier predictions
  625. logits = (cls_logits + distillation_logits) / 2
  626. return DeiTForImageClassificationWithTeacherOutput(
  627. logits=logits,
  628. cls_logits=cls_logits,
  629. distillation_logits=distillation_logits,
  630. hidden_states=outputs.hidden_states,
  631. attentions=outputs.attentions,
  632. )
# Names exported by `from <module> import *`.
__all__ = [
    "DeiTForImageClassification",
    "DeiTForImageClassificationWithTeacher",
    "DeiTForMaskedImageModeling",
    "DeiTModel",
    "DeiTPreTrainedModel",
]