modeling_vit.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691
  1. # coding=utf-8
  2. # Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch ViT model."""
  16. import collections.abc
  17. import math
  18. from typing import Callable, Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. BaseModelOutputWithPooling,
  26. ImageClassifierOutput,
  27. MaskedImageModelingOutput,
  28. )
  29. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  32. from ...utils import TransformersKwargs, auto_docstring, logging, torch_int
  33. from ...utils.generic import can_return_tuple, check_model_inputs
  34. from .configuration_vit import ViTConfig
  35. logger = logging.get_logger(__name__)
  36. class ViTEmbeddings(nn.Module):
  37. """
  38. Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
  39. """
  40. def __init__(self, config: ViTConfig, use_mask_token: bool = False):
  41. super().__init__()
  42. self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  43. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
  44. self.patch_embeddings = ViTPatchEmbeddings(config)
  45. num_patches = self.patch_embeddings.num_patches
  46. self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
  47. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  48. self.patch_size = config.patch_size
  49. self.config = config
  50. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  51. """
  52. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  53. images. This method is also adapted to support torch.jit tracing.
  54. Adapted from:
  55. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  56. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  57. """
  58. num_patches = embeddings.shape[1] - 1
  59. num_positions = self.position_embeddings.shape[1] - 1
  60. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  61. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  62. return self.position_embeddings
  63. class_pos_embed = self.position_embeddings[:, :1]
  64. patch_pos_embed = self.position_embeddings[:, 1:]
  65. dim = embeddings.shape[-1]
  66. new_height = height // self.patch_size
  67. new_width = width // self.patch_size
  68. sqrt_num_positions = torch_int(num_positions**0.5)
  69. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  70. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  71. patch_pos_embed = nn.functional.interpolate(
  72. patch_pos_embed,
  73. size=(new_height, new_width),
  74. mode="bicubic",
  75. align_corners=False,
  76. )
  77. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  78. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  79. def forward(
  80. self,
  81. pixel_values: torch.Tensor,
  82. bool_masked_pos: Optional[torch.BoolTensor] = None,
  83. interpolate_pos_encoding: bool = False,
  84. ) -> torch.Tensor:
  85. batch_size, num_channels, height, width = pixel_values.shape
  86. embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  87. if bool_masked_pos is not None:
  88. seq_length = embeddings.shape[1]
  89. mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
  90. # replace the masked visual tokens by mask_tokens
  91. mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  92. embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  93. # add the [CLS] token to the embedded patch tokens
  94. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  95. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  96. # add positional encoding to each token
  97. if interpolate_pos_encoding:
  98. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  99. else:
  100. embeddings = embeddings + self.position_embeddings
  101. embeddings = self.dropout(embeddings)
  102. return embeddings
  103. class ViTPatchEmbeddings(nn.Module):
  104. """
  105. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  106. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  107. Transformer.
  108. """
  109. def __init__(self, config: ViTConfig):
  110. super().__init__()
  111. image_size, patch_size = config.image_size, config.patch_size
  112. num_channels, hidden_size = config.num_channels, config.hidden_size
  113. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  114. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  115. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  116. self.image_size = image_size
  117. self.patch_size = patch_size
  118. self.num_channels = num_channels
  119. self.num_patches = num_patches
  120. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  121. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  122. batch_size, num_channels, height, width = pixel_values.shape
  123. if num_channels != self.num_channels:
  124. raise ValueError(
  125. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  126. f" Expected {self.num_channels} but got {num_channels}."
  127. )
  128. if not interpolate_pos_encoding:
  129. if height != self.image_size[0] or width != self.image_size[1]:
  130. raise ValueError(
  131. f"Input image size ({height}*{width}) doesn't match model"
  132. f" ({self.image_size[0]}*{self.image_size[1]})."
  133. )
  134. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  135. return embeddings
  136. def eager_attention_forward(
  137. module: nn.Module,
  138. query: torch.Tensor,
  139. key: torch.Tensor,
  140. value: torch.Tensor,
  141. attention_mask: Optional[torch.Tensor],
  142. scaling: float,
  143. dropout: float = 0.0,
  144. **kwargs,
  145. ):
  146. # Take the dot product between "query" and "key" to get the raw attention scores.
  147. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  148. # Normalize the attention scores to probabilities.
  149. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  150. # This is actually dropping out entire tokens to attend to, which might
  151. # seem a bit unusual, but is taken from the original Transformer paper.
  152. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  153. # Mask heads if we want to
  154. if attention_mask is not None:
  155. attn_weights = attn_weights * attention_mask
  156. attn_output = torch.matmul(attn_weights, value)
  157. attn_output = attn_output.transpose(1, 2).contiguous()
  158. return attn_output, attn_weights
  159. class ViTSelfAttention(nn.Module):
  160. def __init__(self, config: ViTConfig):
  161. super().__init__()
  162. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  163. raise ValueError(
  164. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  165. f"heads {config.num_attention_heads}."
  166. )
  167. self.config = config
  168. self.num_attention_heads = config.num_attention_heads
  169. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  170. self.all_head_size = self.num_attention_heads * self.attention_head_size
  171. self.dropout_prob = config.attention_probs_dropout_prob
  172. self.scaling = self.attention_head_size**-0.5
  173. self.is_causal = False
  174. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  175. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  176. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  177. def forward(
  178. self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
  179. ) -> tuple[torch.Tensor, torch.Tensor]:
  180. batch_size = hidden_states.shape[0]
  181. new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
  182. key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
  183. value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
  184. query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
  185. attention_interface: Callable = eager_attention_forward
  186. if self.config._attn_implementation != "eager":
  187. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  188. context_layer, attention_probs = attention_interface(
  189. self,
  190. query_layer,
  191. key_layer,
  192. value_layer,
  193. head_mask,
  194. is_causal=self.is_causal,
  195. scaling=self.scaling,
  196. dropout=0.0 if not self.training else self.dropout_prob,
  197. )
  198. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  199. context_layer = context_layer.reshape(new_context_layer_shape)
  200. return context_layer, attention_probs
  201. class ViTSelfOutput(nn.Module):
  202. """
  203. The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
  204. layernorm applied before each block.
  205. """
  206. def __init__(self, config: ViTConfig):
  207. super().__init__()
  208. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  209. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  210. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  211. hidden_states = self.dense(hidden_states)
  212. hidden_states = self.dropout(hidden_states)
  213. return hidden_states
  214. class ViTAttention(nn.Module):
  215. def __init__(self, config: ViTConfig):
  216. super().__init__()
  217. self.attention = ViTSelfAttention(config)
  218. self.output = ViTSelfOutput(config)
  219. self.pruned_heads = set()
  220. def prune_heads(self, heads: set[int]):
  221. if len(heads) == 0:
  222. return
  223. heads, index = find_pruneable_heads_and_indices(
  224. heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
  225. )
  226. # Prune linear layers
  227. self.attention.query = prune_linear_layer(self.attention.query, index)
  228. self.attention.key = prune_linear_layer(self.attention.key, index)
  229. self.attention.value = prune_linear_layer(self.attention.value, index)
  230. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  231. # Update hyper params and store pruned heads
  232. self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
  233. self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
  234. self.pruned_heads = self.pruned_heads.union(heads)
  235. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  236. self_attn_output, _ = self.attention(hidden_states, head_mask)
  237. output = self.output(self_attn_output, hidden_states)
  238. return output
  239. class ViTIntermediate(nn.Module):
  240. def __init__(self, config: ViTConfig):
  241. super().__init__()
  242. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  243. if isinstance(config.hidden_act, str):
  244. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  245. else:
  246. self.intermediate_act_fn = config.hidden_act
  247. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  248. hidden_states = self.dense(hidden_states)
  249. hidden_states = self.intermediate_act_fn(hidden_states)
  250. return hidden_states
  251. class ViTOutput(nn.Module):
  252. def __init__(self, config: ViTConfig):
  253. super().__init__()
  254. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  255. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  256. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  257. hidden_states = self.dense(hidden_states)
  258. hidden_states = self.dropout(hidden_states)
  259. hidden_states = hidden_states + input_tensor
  260. return hidden_states
  261. class ViTLayer(GradientCheckpointingLayer):
  262. """This corresponds to the Block class in the timm implementation."""
  263. def __init__(self, config: ViTConfig):
  264. super().__init__()
  265. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  266. self.seq_len_dim = 1
  267. self.attention = ViTAttention(config)
  268. self.intermediate = ViTIntermediate(config)
  269. self.output = ViTOutput(config)
  270. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  271. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  272. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  273. hidden_states_norm = self.layernorm_before(hidden_states)
  274. attention_output = self.attention(hidden_states_norm, head_mask)
  275. # first residual connection
  276. hidden_states = attention_output + hidden_states
  277. # in ViT, layernorm is also applied after self-attention
  278. layer_output = self.layernorm_after(hidden_states)
  279. layer_output = self.intermediate(layer_output)
  280. # second residual connection is done here
  281. layer_output = self.output(layer_output, hidden_states)
  282. return layer_output
  283. class ViTEncoder(nn.Module):
  284. def __init__(self, config: ViTConfig):
  285. super().__init__()
  286. self.config = config
  287. self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
  288. self.gradient_checkpointing = False
  289. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> BaseModelOutput:
  290. for i, layer_module in enumerate(self.layer):
  291. layer_head_mask = head_mask[i] if head_mask is not None else None
  292. hidden_states = layer_module(hidden_states, layer_head_mask)
  293. return BaseModelOutput(last_hidden_state=hidden_states)
  294. @auto_docstring
  295. class ViTPreTrainedModel(PreTrainedModel):
  296. config: ViTConfig
  297. base_model_prefix = "vit"
  298. main_input_name = "pixel_values"
  299. supports_gradient_checkpointing = True
  300. _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
  301. _supports_sdpa = True
  302. _supports_flash_attn = True
  303. _supports_flex_attn = True
  304. _supports_attention_backend = True
  305. _can_record_outputs = {
  306. "hidden_states": ViTLayer,
  307. "attentions": ViTSelfAttention,
  308. }
  309. def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]):
  310. """Initialize the weights"""
  311. if isinstance(module, (nn.Linear, nn.Conv2d)):
  312. # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
  313. # `trunc_normal_cpu` not implemented in `half` issues
  314. module.weight.data = nn.init.trunc_normal_(
  315. module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
  316. ).to(module.weight.dtype)
  317. if module.bias is not None:
  318. module.bias.data.zero_()
  319. elif isinstance(module, nn.LayerNorm):
  320. module.bias.data.zero_()
  321. module.weight.data.fill_(1.0)
  322. elif isinstance(module, ViTEmbeddings):
  323. module.position_embeddings.data = nn.init.trunc_normal_(
  324. module.position_embeddings.data.to(torch.float32),
  325. mean=0.0,
  326. std=self.config.initializer_range,
  327. ).to(module.position_embeddings.dtype)
  328. module.cls_token.data = nn.init.trunc_normal_(
  329. module.cls_token.data.to(torch.float32),
  330. mean=0.0,
  331. std=self.config.initializer_range,
  332. ).to(module.cls_token.dtype)
  333. if module.mask_token is not None:
  334. module.mask_token.data.zero_()
  335. @auto_docstring
  336. class ViTModel(ViTPreTrainedModel):
  337. def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
  338. r"""
  339. add_pooling_layer (bool, *optional*, defaults to `True`):
  340. Whether to add a pooling layer
  341. use_mask_token (`bool`, *optional*, defaults to `False`):
  342. Whether to use a mask token for masked image modeling.
  343. """
  344. super().__init__(config)
  345. self.config = config
  346. self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
  347. self.encoder = ViTEncoder(config)
  348. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  349. self.pooler = ViTPooler(config) if add_pooling_layer else None
  350. # Initialize weights and apply final processing
  351. self.post_init()
  352. def get_input_embeddings(self) -> ViTPatchEmbeddings:
  353. return self.embeddings.patch_embeddings
  354. def _prune_heads(self, heads_to_prune: dict[int, list[int]]):
  355. """
  356. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  357. class PreTrainedModel
  358. """
  359. for layer, heads in heads_to_prune.items():
  360. self.encoder.layer[layer].attention.prune_heads(heads)
  361. @check_model_inputs(tie_last_hidden_states=False)
  362. @auto_docstring
  363. def forward(
  364. self,
  365. pixel_values: Optional[torch.Tensor] = None,
  366. bool_masked_pos: Optional[torch.BoolTensor] = None,
  367. head_mask: Optional[torch.Tensor] = None,
  368. interpolate_pos_encoding: Optional[bool] = None,
  369. **kwargs: Unpack[TransformersKwargs],
  370. ) -> BaseModelOutputWithPooling:
  371. r"""
  372. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  373. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  374. """
  375. if pixel_values is None:
  376. raise ValueError("You have to specify pixel_values")
  377. # Prepare head mask if needed
  378. # 1.0 in head_mask indicate we keep the head
  379. # attention_probs has shape bsz x n_heads x N x N
  380. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  381. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  382. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  383. # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
  384. expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
  385. if pixel_values.dtype != expected_dtype:
  386. pixel_values = pixel_values.to(expected_dtype)
  387. embedding_output = self.embeddings(
  388. pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
  389. )
  390. encoder_outputs: BaseModelOutput = self.encoder(embedding_output, head_mask=head_mask)
  391. sequence_output = encoder_outputs.last_hidden_state
  392. sequence_output = self.layernorm(sequence_output)
  393. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  394. return BaseModelOutputWithPooling(last_hidden_state=sequence_output, pooler_output=pooled_output)
  395. class ViTPooler(nn.Module):
  396. def __init__(self, config: ViTConfig):
  397. super().__init__()
  398. self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
  399. self.activation = ACT2FN[config.pooler_act]
  400. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  401. # We "pool" the model by simply taking the hidden state corresponding
  402. # to the first token.
  403. first_token_tensor = hidden_states[:, 0]
  404. pooled_output = self.dense(first_token_tensor)
  405. pooled_output = self.activation(pooled_output)
  406. return pooled_output
  407. @auto_docstring(
  408. custom_intro="""
  409. ViT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://huggingface.co/papers/2111.09886).
  410. <Tip>
  411. Note that we provide a script to pre-train this model on custom data in our [examples
  412. directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
  413. </Tip>
  414. """
  415. )
  416. class ViTForMaskedImageModeling(ViTPreTrainedModel):
  417. def __init__(self, config: ViTConfig):
  418. super().__init__(config)
  419. self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
  420. self.decoder = nn.Sequential(
  421. nn.Conv2d(
  422. in_channels=config.hidden_size,
  423. out_channels=config.encoder_stride**2 * config.num_channels,
  424. kernel_size=1,
  425. ),
  426. nn.PixelShuffle(config.encoder_stride),
  427. )
  428. # Initialize weights and apply final processing
  429. self.post_init()
  430. @can_return_tuple
  431. @auto_docstring
  432. def forward(
  433. self,
  434. pixel_values: Optional[torch.Tensor] = None,
  435. bool_masked_pos: Optional[torch.BoolTensor] = None,
  436. head_mask: Optional[torch.Tensor] = None,
  437. interpolate_pos_encoding: Optional[bool] = None,
  438. **kwargs: Unpack[TransformersKwargs],
  439. ) -> MaskedImageModelingOutput:
  440. r"""
  441. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
  442. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  443. Examples:
  444. ```python
  445. >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling
  446. >>> import torch
  447. >>> from PIL import Image
  448. >>> import requests
  449. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  450. >>> image = Image.open(requests.get(url, stream=True).raw)
  451. >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
  452. >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")
  453. >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
  454. >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
  455. >>> # create random boolean mask of shape (batch_size, num_patches)
  456. >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
  457. >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
  458. >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
  459. >>> list(reconstructed_pixel_values.shape)
  460. [1, 3, 224, 224]
  461. ```"""
  462. if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride):
  463. raise ValueError(
  464. "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that "
  465. "the reconstructed image has the same dimensions as the input. "
  466. f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}."
  467. )
  468. outputs: BaseModelOutputWithPooling = self.vit(
  469. pixel_values,
  470. bool_masked_pos=bool_masked_pos,
  471. head_mask=head_mask,
  472. interpolate_pos_encoding=interpolate_pos_encoding,
  473. **kwargs,
  474. )
  475. sequence_output = outputs.last_hidden_state
  476. # Reshape to (batch_size, num_channels, height, width)
  477. sequence_output = sequence_output[:, 1:]
  478. batch_size, sequence_length, num_channels = sequence_output.shape
  479. height = width = math.floor(sequence_length**0.5)
  480. sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
  481. # Reconstruct pixel values
  482. reconstructed_pixel_values = self.decoder(sequence_output)
  483. masked_im_loss = None
  484. if bool_masked_pos is not None:
  485. size = self.config.image_size // self.config.patch_size
  486. bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
  487. mask = (
  488. bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
  489. .repeat_interleave(self.config.patch_size, 2)
  490. .unsqueeze(1)
  491. .contiguous()
  492. )
  493. reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
  494. masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
  495. return MaskedImageModelingOutput(
  496. loss=masked_im_loss,
  497. reconstruction=reconstructed_pixel_values,
  498. hidden_states=outputs.hidden_states,
  499. attentions=outputs.attentions,
  500. )
  501. @auto_docstring(
  502. custom_intro="""
  503. ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
  504. the [CLS] token) e.g. for ImageNet.
  505. <Tip>
  506. Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
  507. setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
  508. position embeddings to the higher resolution.
  509. </Tip>
  510. """
  511. )
  512. class ViTForImageClassification(ViTPreTrainedModel):
  513. def __init__(self, config: ViTConfig):
  514. super().__init__(config)
  515. self.num_labels = config.num_labels
  516. self.vit = ViTModel(config, add_pooling_layer=False)
  517. # Classifier head
  518. self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  519. # Initialize weights and apply final processing
  520. self.post_init()
  521. @can_return_tuple
  522. @auto_docstring
  523. def forward(
  524. self,
  525. pixel_values: Optional[torch.Tensor] = None,
  526. head_mask: Optional[torch.Tensor] = None,
  527. labels: Optional[torch.Tensor] = None,
  528. interpolate_pos_encoding: Optional[bool] = None,
  529. **kwargs: Unpack[TransformersKwargs],
  530. ) -> ImageClassifierOutput:
  531. r"""
  532. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  533. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  534. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  535. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  536. """
  537. outputs: BaseModelOutputWithPooling = self.vit(
  538. pixel_values,
  539. head_mask=head_mask,
  540. interpolate_pos_encoding=interpolate_pos_encoding,
  541. **kwargs,
  542. )
  543. sequence_output = outputs.last_hidden_state
  544. pooled_output = sequence_output[:, 0, :]
  545. logits = self.classifier(pooled_output)
  546. loss = None
  547. if labels is not None:
  548. loss = self.loss_function(labels, logits, self.config, **kwargs)
  549. return ImageClassifierOutput(
  550. loss=loss,
  551. logits=logits,
  552. hidden_states=outputs.hidden_states,
  553. attentions=outputs.attentions,
  554. )
  555. __all__ = ["ViTForImageClassification", "ViTForMaskedImageModeling", "ViTModel", "ViTPreTrainedModel"]