modeling_dinov2.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684
  1. # coding=utf-8
  2. # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch DINOv2 model."""
  16. import collections.abc
  17. from typing import Callable, Optional, Union
  18. import torch
  19. from torch import nn
  20. from ...activations import ACT2FN
  21. from ...modeling_layers import GradientCheckpointingLayer
  22. from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
  23. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  24. from ...processing_utils import Unpack
  25. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  26. from ...utils import TransformersKwargs, auto_docstring, logging, torch_int
  27. from ...utils.backbone_utils import BackboneMixin
  28. from ...utils.generic import can_return_tuple, check_model_inputs
  29. from .configuration_dinov2 import Dinov2Config
# Module-level logger for this file.
logger = logging.get_logger(__name__)
  31. class Dinov2Embeddings(nn.Module):
  32. """
  33. Construct the CLS token, mask token, position and patch embeddings.
  34. """
  35. def __init__(self, config: Dinov2Config) -> None:
  36. super().__init__()
  37. self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  38. if config.use_mask_token:
  39. self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
  40. self.patch_embeddings = Dinov2PatchEmbeddings(config)
  41. num_patches = self.patch_embeddings.num_patches
  42. self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
  43. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  44. self.patch_size = config.patch_size
  45. self.use_mask_token = config.use_mask_token
  46. self.config = config
  47. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  48. """
  49. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  50. images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision.
  51. Adapted from:
  52. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  53. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  54. """
  55. num_patches = embeddings.shape[1] - 1
  56. num_positions = self.position_embeddings.shape[1] - 1
  57. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  58. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  59. return self.position_embeddings
  60. class_pos_embed = self.position_embeddings[:, :1]
  61. patch_pos_embed = self.position_embeddings[:, 1:]
  62. dim = embeddings.shape[-1]
  63. new_height = height // self.patch_size
  64. new_width = width // self.patch_size
  65. sqrt_num_positions = torch_int(num_positions**0.5)
  66. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  67. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  68. target_dtype = patch_pos_embed.dtype
  69. patch_pos_embed = nn.functional.interpolate(
  70. patch_pos_embed.to(torch.float32),
  71. size=(new_height, new_width),
  72. mode="bicubic",
  73. align_corners=False,
  74. ).to(dtype=target_dtype)
  75. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  76. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  77. def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
  78. batch_size, _, height, width = pixel_values.shape
  79. target_dtype = self.patch_embeddings.projection.weight.dtype
  80. embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
  81. if bool_masked_pos is not None and self.use_mask_token:
  82. embeddings = torch.where(
  83. bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
  84. )
  85. # add the [CLS] token to the embedded patch tokens
  86. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  87. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  88. # add positional encoding to each token
  89. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  90. embeddings = self.dropout(embeddings)
  91. return embeddings
  92. class Dinov2PatchEmbeddings(nn.Module):
  93. """
  94. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  95. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  96. Transformer.
  97. """
  98. def __init__(self, config):
  99. super().__init__()
  100. image_size, patch_size = config.image_size, config.patch_size
  101. num_channels, hidden_size = config.num_channels, config.hidden_size
  102. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  103. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  104. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  105. self.image_size = image_size
  106. self.patch_size = patch_size
  107. self.num_channels = num_channels
  108. self.num_patches = num_patches
  109. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  110. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  111. num_channels = pixel_values.shape[1]
  112. if num_channels != self.num_channels:
  113. raise ValueError(
  114. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  115. f" Expected {self.num_channels} but got {num_channels}."
  116. )
  117. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  118. return embeddings
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Plain PyTorch (matmul + softmax) attention. It is the default used by `Dinov2SelfAttention` when
    `config._attn_implementation == "eager"`.

    Args:
        module: the calling attention module. Only its `training` flag is read, to enable dropout.
        query: per-head query projections. As produced by `Dinov2SelfAttention`, query, key and value are
            `(batch_size, num_heads, seq_length, head_size)`.
        key: per-head key projections.
        value: per-head value projections.
        attention_mask: despite the name, `Dinov2SelfAttention` passes its per-layer *head mask* in this slot. It is
            multiplied into the attention probabilities after softmax and dropout.
        scaling: factor applied to the raw `query @ key^T` scores.
        dropout: dropout probability applied to the attention probabilities.
        **kwargs: accepted for interface parity with other attention backends (e.g. `is_causal`) and ignored here.

    Returns:
        `(attn_output, attn_weights)`. `attn_output` is the attended values transposed to
        `(batch_size, seq_length, num_heads, head_size)`. `attn_weights` is the post-softmax attention weights,
        after dropout and masking.
    """
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    # The softmax is computed in float32 and cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    # (batch, heads, seq, head_size) -> (batch, seq, heads, head_size)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
class Dinov2SelfAttention(nn.Module):
    """
    Multi-head self-attention over a token sequence.

    The query/key/value projections are computed here. The attention itself is delegated to
    `eager_attention_forward`, or to the backend registered in `ALL_ATTENTION_FUNCTIONS` under
    `config._attn_implementation`.
    """

    def __init__(self, config: Dinov2Config):
        super().__init__()
        # The divisibility check is skipped when the config defines `embedding_size`.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        # 1 / sqrt(head_size) scaling of the dot-product scores
        self.scaling = self.attention_head_size**-0.5
        # bidirectional attention: no causal masking
        self.is_causal = False

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

    def forward(
        self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: input tokens. The first dimension is the batch and the last is `hidden_size`
                (presumably `(batch_size, seq_length, hidden_size)` — confirm against callers).
            head_mask: optional mask forwarded positionally into the attention function's `attention_mask` slot.

        Returns:
            `(context_layer, attention_probs)`. `context_layer` is the attended values with the head dimensions
            merged into one last dimension of size `all_head_size`. `attention_probs` is whatever weights the
            attention function returned.
        """
        batch_size = hidden_states.shape[0]
        # split features into heads: (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
        new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size

        key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
        value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
        query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            # attention dropout is only applied while training
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        # merge the two trailing (heads, head_size) dimensions into all_head_size
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        return context_layer, attention_probs
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
class Dinov2SelfOutput(nn.Module):
    """
    Output projection (dense + dropout) applied to the self-attention result.

    The residual connection is added in Dinov2Layer rather than here, unlike some other models, because of the
    layernorm applied before each block.
    """

    def __init__(self, config: Dinov2Config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # `input_tensor` is not used (no residual add in this module). It is presumably kept so the
        # signature matches the ViT module this class is copied from.
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
class Dinov2Attention(nn.Module):
    """
    Self-attention sub-block: `Dinov2SelfAttention` followed by the `Dinov2SelfOutput` projection, with support for
    pruning attention heads. No residual connection is added here.
    """

    def __init__(self, config: Dinov2Config):
        super().__init__()
        self.attention = Dinov2SelfAttention(config)
        self.output = Dinov2SelfOutput(config)
        # heads removed so far, accumulated across `prune_heads` calls
        self.pruned_heads = set()

    def prune_heads(self, heads: set[int]):
        """
        Remove the given attention heads from the query/key/value projections and from the output projection,
        then update the head-count bookkeeping. An empty `heads` is a no-op.
        """
        if len(heads) == 0:
            return
        # `index` is applied to every projection below via `prune_linear_layer`. Presumably it selects the
        # features that survive pruning; see `find_pruneable_heads_and_indices` for the exact contract.
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        # the output projection is pruned along dim=1 (presumably its input features)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # the attention weights returned by the self-attention are discarded here
        self_attn_output, _ = self.attention(hidden_states, head_mask)
        output = self.output(self_attn_output, hidden_states)
        return output
  226. class Dinov2LayerScale(nn.Module):
  227. def __init__(self, config) -> None:
  228. super().__init__()
  229. self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))
  230. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  231. return hidden_state * self.lambda1
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.

    Args:
        input: tensor whose first dimension indexes samples. Any number of trailing dimensions is supported.
        drop_prob: probability of zeroing a sample's whole path.
        training: when False, or when `drop_prob == 0.0`, `input` is returned unchanged.

    Returns:
        `input` with each sample either zeroed (probability `drop_prob`) or scaled by `1 / (1 - drop_prob)`.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else to 0: one draw per sample
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    # rescale the kept samples by 1 / keep_prob so the expected value is unchanged
    output = input.div(keep_prob) * random_tensor
    return output
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
class Dinov2DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        # NOTE(review): the default `None` is not usable in training mode, because `drop_path` would evaluate
        # `1 - None` and raise TypeError. In this file the module is only built when
        # `config.drop_path_rate > 0.0`.
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # only active while `self.training` is True; in eval mode `drop_path` returns the input unchanged
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        # shown inside the module's repr, e.g. `Dinov2DropPath(p=0.1)`
        return f"p={self.drop_prob}"
  260. class Dinov2MLP(nn.Module):
  261. def __init__(self, config) -> None:
  262. super().__init__()
  263. in_features = out_features = config.hidden_size
  264. hidden_features = int(config.hidden_size * config.mlp_ratio)
  265. self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
  266. if isinstance(config.hidden_act, str):
  267. self.activation = ACT2FN[config.hidden_act]
  268. else:
  269. self.activation = config.hidden_act
  270. self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
  271. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  272. hidden_state = self.fc1(hidden_state)
  273. hidden_state = self.activation(hidden_state)
  274. hidden_state = self.fc2(hidden_state)
  275. return hidden_state
  276. class Dinov2SwiGLUFFN(nn.Module):
  277. def __init__(self, config) -> None:
  278. super().__init__()
  279. in_features = out_features = config.hidden_size
  280. hidden_features = int(config.hidden_size * config.mlp_ratio)
  281. hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
  282. self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
  283. self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
  284. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  285. hidden_state = self.weights_in(hidden_state)
  286. x1, x2 = hidden_state.chunk(2, dim=-1)
  287. hidden = nn.functional.silu(x1) * x2
  288. return self.weights_out(hidden)
  289. class Dinov2Layer(GradientCheckpointingLayer):
  290. """This corresponds to the Block class in the original implementation."""
  291. def __init__(self, config: Dinov2Config) -> None:
  292. super().__init__()
  293. self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  294. self.attention = Dinov2Attention(config)
  295. self.layer_scale1 = Dinov2LayerScale(config)
  296. self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
  297. self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  298. if config.use_swiglu_ffn:
  299. self.mlp = Dinov2SwiGLUFFN(config)
  300. else:
  301. self.mlp = Dinov2MLP(config)
  302. self.layer_scale2 = Dinov2LayerScale(config)
  303. def forward(
  304. self,
  305. hidden_states: torch.Tensor,
  306. head_mask: Optional[torch.Tensor] = None,
  307. ) -> torch.Tensor:
  308. hidden_states_norm = self.norm1(hidden_states)
  309. self_attention_output = self.attention(hidden_states_norm, head_mask)
  310. self_attention_output = self.layer_scale1(self_attention_output)
  311. # first residual connection
  312. hidden_states = self.drop_path(self_attention_output) + hidden_states
  313. # in Dinov2, layernorm is also applied after self-attention
  314. layer_output = self.norm2(hidden_states)
  315. layer_output = self.mlp(layer_output)
  316. layer_output = self.layer_scale2(layer_output)
  317. # second residual connection
  318. layer_output = self.drop_path(layer_output) + hidden_states
  319. return layer_output
  320. class Dinov2Encoder(nn.Module):
  321. def __init__(self, config: Dinov2Config):
  322. super().__init__()
  323. self.config = config
  324. self.layer = nn.ModuleList([Dinov2Layer(config) for _ in range(config.num_hidden_layers)])
  325. self.gradient_checkpointing = False
  326. def forward(
  327. self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False
  328. ) -> BaseModelOutput:
  329. all_hidden_states = [hidden_states] if output_hidden_states else None
  330. for i, layer_module in enumerate(self.layer):
  331. layer_head_mask = head_mask[i] if head_mask is not None else None
  332. hidden_states = layer_module(hidden_states, layer_head_mask)
  333. if all_hidden_states:
  334. all_hidden_states.append(hidden_states)
  335. return BaseModelOutput(
  336. last_hidden_state=hidden_states,
  337. hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
  338. )
  339. @auto_docstring
  340. class Dinov2PreTrainedModel(PreTrainedModel):
  341. config: Dinov2Config
  342. base_model_prefix = "dinov2"
  343. main_input_name = "pixel_values"
  344. supports_gradient_checkpointing = True
  345. _no_split_modules = ["Dinov2Layer"]
  346. _supports_sdpa = True
  347. _supports_flash_attn = True
  348. _supports_flex_attn = True
  349. _supports_attention_backend = True
  350. _can_record_outputs = {
  351. "attentions": Dinov2SelfAttention,
  352. }
  353. def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
  354. """Initialize the weights"""
  355. if isinstance(module, (nn.Linear, nn.Conv2d)):
  356. # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
  357. # `trunc_normal_cpu` not implemented in `half` issues
  358. module.weight.data = nn.init.trunc_normal_(
  359. module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
  360. ).to(module.weight.dtype)
  361. if module.bias is not None:
  362. module.bias.data.zero_()
  363. elif isinstance(module, nn.LayerNorm):
  364. module.bias.data.zero_()
  365. module.weight.data.fill_(1.0)
  366. elif isinstance(module, Dinov2Embeddings):
  367. module.position_embeddings.data = nn.init.trunc_normal_(
  368. module.position_embeddings.data.to(torch.float32),
  369. mean=0.0,
  370. std=self.config.initializer_range,
  371. ).to(module.position_embeddings.dtype)
  372. module.cls_token.data = nn.init.trunc_normal_(
  373. module.cls_token.data.to(torch.float32),
  374. mean=0.0,
  375. std=self.config.initializer_range,
  376. ).to(module.cls_token.dtype)
  377. if self.config.use_mask_token:
  378. module.mask_token.data.zero_()
  379. elif isinstance(module, Dinov2LayerScale):
  380. module.lambda1.data.fill_(self.config.layerscale_value)
  381. @auto_docstring
  382. class Dinov2Model(Dinov2PreTrainedModel):
  383. def __init__(self, config: Dinov2Config):
  384. super().__init__(config)
  385. self.config = config
  386. self.embeddings = Dinov2Embeddings(config)
  387. self.encoder = Dinov2Encoder(config)
  388. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  389. # Initialize weights and apply final processing
  390. self.post_init()
  391. def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
  392. return self.embeddings.patch_embeddings
  393. def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
  394. """
  395. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  396. class PreTrainedModel
  397. """
  398. for layer, heads in heads_to_prune.items():
  399. self.encoder.layer[layer].attention.prune_heads(heads)
  400. @check_model_inputs(tie_last_hidden_states=False)
  401. @auto_docstring
  402. def forward(
  403. self,
  404. pixel_values: Optional[torch.Tensor] = None,
  405. bool_masked_pos: Optional[torch.Tensor] = None,
  406. head_mask: Optional[torch.Tensor] = None,
  407. output_hidden_states: Optional[bool] = None,
  408. **kwargs,
  409. ) -> BaseModelOutputWithPooling:
  410. r"""
  411. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
  412. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
  413. pre-training.
  414. """
  415. if output_hidden_states is None:
  416. output_hidden_states = self.config.output_hidden_states
  417. if pixel_values is None:
  418. raise ValueError("You have to specify pixel_values")
  419. # Prepare head mask if needed
  420. # 1.0 in head_mask indicate we keep the head
  421. # attention_probs has shape bsz x n_heads x N x N
  422. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  423. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  424. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  425. embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
  426. encoder_outputs: BaseModelOutput = self.encoder(
  427. embedding_output, head_mask=head_mask, output_hidden_states=output_hidden_states
  428. )
  429. sequence_output = encoder_outputs.last_hidden_state
  430. sequence_output = self.layernorm(sequence_output)
  431. pooled_output = sequence_output[:, 0, :]
  432. return BaseModelOutputWithPooling(
  433. last_hidden_state=sequence_output,
  434. pooler_output=pooled_output,
  435. hidden_states=encoder_outputs.hidden_states,
  436. )
  437. @auto_docstring(
  438. custom_intro="""
  439. Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
  440. of the [CLS] token) e.g. for ImageNet.
  441. """
  442. )
  443. class Dinov2ForImageClassification(Dinov2PreTrainedModel):
  444. def __init__(self, config: Dinov2Config) -> None:
  445. super().__init__(config)
  446. self.num_labels = config.num_labels
  447. self.dinov2 = Dinov2Model(config)
  448. # Classifier head
  449. self.classifier = (
  450. nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity()
  451. )
  452. # Initialize weights and apply final processing
  453. self.post_init()
  454. @can_return_tuple
  455. @auto_docstring
  456. def forward(
  457. self,
  458. pixel_values: Optional[torch.Tensor] = None,
  459. head_mask: Optional[torch.Tensor] = None,
  460. labels: Optional[torch.Tensor] = None,
  461. **kwargs: Unpack[TransformersKwargs],
  462. ) -> ImageClassifierOutput:
  463. r"""
  464. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  465. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  466. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  467. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  468. """
  469. outputs: BaseModelOutputWithPooling = self.dinov2(pixel_values, head_mask=head_mask, **kwargs)
  470. sequence_output = outputs.last_hidden_state # batch_size, sequence_length, hidden_size
  471. cls_token = sequence_output[:, 0]
  472. patch_tokens = sequence_output[:, 1:]
  473. linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
  474. logits = self.classifier(linear_input)
  475. loss = None
  476. if labels is not None:
  477. loss = self.loss_function(labels, logits, self.config, **kwargs)
  478. return ImageClassifierOutput(
  479. loss=loss,
  480. logits=logits,
  481. hidden_states=outputs.hidden_states,
  482. attentions=outputs.attentions,
  483. )
  484. @auto_docstring(
  485. custom_intro="""
  486. Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
  487. """
  488. )
  489. class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
  490. def __init__(self, config):
  491. super().__init__(config)
  492. super()._init_backbone(config)
  493. self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
  494. self.embeddings = Dinov2Embeddings(config)
  495. self.encoder = Dinov2Encoder(config)
  496. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  497. # Initialize weights and apply final processing
  498. self.post_init()
  499. def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
  500. return self.embeddings.patch_embeddings
  501. @check_model_inputs()
  502. @auto_docstring
  503. def forward(
  504. self, pixel_values: torch.Tensor, output_hidden_states: Optional[bool] = None, **kwargs
  505. ) -> BackboneOutput:
  506. r"""
  507. Examples:
  508. ```python
  509. >>> from transformers import AutoImageProcessor, AutoBackbone
  510. >>> import torch
  511. >>> from PIL import Image
  512. >>> import requests
  513. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  514. >>> image = Image.open(requests.get(url, stream=True).raw)
  515. >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
  516. >>> model = AutoBackbone.from_pretrained(
  517. ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
  518. ... )
  519. >>> inputs = processor(image, return_tensors="pt")
  520. >>> outputs = model(**inputs)
  521. >>> feature_maps = outputs.feature_maps
  522. >>> list(feature_maps[-1].shape)
  523. [1, 768, 16, 16]
  524. ```"""
  525. if output_hidden_states is None:
  526. output_hidden_states = self.config.output_hidden_states
  527. embedding_output = self.embeddings(pixel_values)
  528. output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True)
  529. hidden_states = output.hidden_states
  530. feature_maps = []
  531. for stage, hidden_state in zip(self.stage_names, hidden_states):
  532. if stage in self.out_features:
  533. if self.config.apply_layernorm:
  534. hidden_state = self.layernorm(hidden_state)
  535. if self.config.reshape_hidden_states:
  536. hidden_state = hidden_state[:, 1:]
  537. # this was actually a bug in the original implementation that we copied here,
  538. # cause normally the order is height, width
  539. batch_size, _, height, width = pixel_values.shape
  540. patch_size = self.config.patch_size
  541. hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
  542. hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
  543. feature_maps.append(hidden_state)
  544. return BackboneOutput(
  545. feature_maps=tuple(feature_maps),
  546. hidden_states=hidden_states if output_hidden_states else None,
  547. )
# Public names exported by this module.
__all__ = ["Dinov2ForImageClassification", "Dinov2Model", "Dinov2PreTrainedModel", "Dinov2Backbone"]