modeling_yolos.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704
  1. # coding=utf-8
  2. # Copyright 2022 School of EIC, Huazhong University of Science & Technology and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch YOLOS model."""
  16. import collections.abc
  17. from dataclasses import dataclass
  18. from typing import Callable, Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  27. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
  28. from ...utils.generic import can_return_tuple, check_model_inputs
  29. from .configuration_yolos import YolosConfig
# Module-level logger for this file, obtained from the library's logging utilities.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`YolosForObjectDetection`].
    """
)
class YolosObjectDetectionOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
        Classification logits (including no-object) for all queries.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~YolosImageProcessor.post_process`] to retrieve the unnormalized bounding
        boxes.
    auxiliary_outputs (`list[Dict]`, *optional*):
        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
        `pred_boxes`) for each intermediate layer of the encoder.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the encoder of the model (after the final
        layer norm).
    """

    loss: Optional[torch.FloatTensor] = None
    loss_dict: Optional[dict] = None
    logits: Optional[torch.FloatTensor] = None
    pred_boxes: Optional[torch.FloatTensor] = None
    auxiliary_outputs: Optional[list[dict]] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Per-layer outputs forwarded unchanged from the base model's output (see YolosForObjectDetection.forward).
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
  67. class YolosEmbeddings(nn.Module):
  68. """
  69. Construct the CLS token, detection tokens, position and patch embeddings.
  70. """
  71. def __init__(self, config: YolosConfig) -> None:
  72. super().__init__()
  73. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  74. self.detection_tokens = nn.Parameter(torch.zeros(1, config.num_detection_tokens, config.hidden_size))
  75. self.patch_embeddings = YolosPatchEmbeddings(config)
  76. num_patches = self.patch_embeddings.num_patches
  77. self.position_embeddings = nn.Parameter(
  78. torch.zeros(1, num_patches + config.num_detection_tokens + 1, config.hidden_size)
  79. )
  80. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  81. self.interpolation = InterpolateInitialPositionEmbeddings(config)
  82. self.config = config
  83. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  84. batch_size, num_channels, height, width = pixel_values.shape
  85. embeddings = self.patch_embeddings(pixel_values)
  86. batch_size, seq_len, _ = embeddings.size()
  87. # add the [CLS] and detection tokens to the embedded patch tokens
  88. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  89. detection_tokens = self.detection_tokens.expand(batch_size, -1, -1)
  90. embeddings = torch.cat((cls_tokens, embeddings, detection_tokens), dim=1)
  91. # add positional encoding to each token
  92. # this might require interpolation of the existing position embeddings
  93. position_embeddings = self.interpolation(self.position_embeddings, (height, width))
  94. embeddings = embeddings + position_embeddings
  95. embeddings = self.dropout(embeddings)
  96. return embeddings
  97. class InterpolateInitialPositionEmbeddings(nn.Module):
  98. def __init__(self, config) -> None:
  99. super().__init__()
  100. self.config = config
  101. def forward(self, pos_embed, img_size=(800, 1344)) -> torch.Tensor:
  102. cls_pos_embed = pos_embed[:, 0, :]
  103. cls_pos_embed = cls_pos_embed[:, None]
  104. det_pos_embed = pos_embed[:, -self.config.num_detection_tokens :, :]
  105. patch_pos_embed = pos_embed[:, 1 : -self.config.num_detection_tokens, :]
  106. patch_pos_embed = patch_pos_embed.transpose(1, 2)
  107. batch_size, hidden_size, seq_len = patch_pos_embed.shape
  108. patch_height, patch_width = (
  109. self.config.image_size[0] // self.config.patch_size,
  110. self.config.image_size[1] // self.config.patch_size,
  111. )
  112. patch_pos_embed = patch_pos_embed.view(batch_size, hidden_size, patch_height, patch_width)
  113. height, width = img_size
  114. new_patch_height, new_patch_width = height // self.config.patch_size, width // self.config.patch_size
  115. patch_pos_embed = nn.functional.interpolate(
  116. patch_pos_embed, size=(new_patch_height, new_patch_width), mode="bicubic", align_corners=False
  117. )
  118. patch_pos_embed = patch_pos_embed.flatten(2).transpose(1, 2)
  119. scale_pos_embed = torch.cat((cls_pos_embed, patch_pos_embed, det_pos_embed), dim=1)
  120. return scale_pos_embed
  121. class InterpolateMidPositionEmbeddings(nn.Module):
  122. def __init__(self, config) -> None:
  123. super().__init__()
  124. self.config = config
  125. def forward(self, pos_embed, img_size=(800, 1344)) -> torch.Tensor:
  126. cls_pos_embed = pos_embed[:, :, 0, :]
  127. cls_pos_embed = cls_pos_embed[:, None]
  128. det_pos_embed = pos_embed[:, :, -self.config.num_detection_tokens :, :]
  129. patch_pos_embed = pos_embed[:, :, 1 : -self.config.num_detection_tokens, :]
  130. patch_pos_embed = patch_pos_embed.transpose(2, 3)
  131. depth, batch_size, hidden_size, seq_len = patch_pos_embed.shape
  132. patch_height, patch_width = (
  133. self.config.image_size[0] // self.config.patch_size,
  134. self.config.image_size[1] // self.config.patch_size,
  135. )
  136. patch_pos_embed = patch_pos_embed.view(depth * batch_size, hidden_size, patch_height, patch_width)
  137. height, width = img_size
  138. new_patch_height, new_patch_width = height // self.config.patch_size, width // self.config.patch_size
  139. patch_pos_embed = nn.functional.interpolate(
  140. patch_pos_embed, size=(new_patch_height, new_patch_width), mode="bicubic", align_corners=False
  141. )
  142. patch_pos_embed = (
  143. patch_pos_embed.flatten(2)
  144. .transpose(1, 2)
  145. .contiguous()
  146. .view(depth, batch_size, new_patch_height * new_patch_width, hidden_size)
  147. )
  148. scale_pos_embed = torch.cat((cls_pos_embed, patch_pos_embed, det_pos_embed), dim=2)
  149. return scale_pos_embed
  150. class YolosPatchEmbeddings(nn.Module):
  151. """
  152. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  153. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  154. Transformer.
  155. """
  156. def __init__(self, config):
  157. super().__init__()
  158. image_size, patch_size = config.image_size, config.patch_size
  159. num_channels, hidden_size = config.num_channels, config.hidden_size
  160. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  161. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  162. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  163. self.image_size = image_size
  164. self.patch_size = patch_size
  165. self.num_channels = num_channels
  166. self.num_patches = num_patches
  167. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  168. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  169. batch_size, num_channels, height, width = pixel_values.shape
  170. if num_channels != self.num_channels:
  171. raise ValueError(
  172. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  173. )
  174. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  175. return embeddings
  176. # Copied from transformers.models.vit.modeling_vit.eager_attention_forward
  177. def eager_attention_forward(
  178. module: nn.Module,
  179. query: torch.Tensor,
  180. key: torch.Tensor,
  181. value: torch.Tensor,
  182. attention_mask: Optional[torch.Tensor],
  183. scaling: float,
  184. dropout: float = 0.0,
  185. **kwargs,
  186. ):
  187. # Take the dot product between "query" and "key" to get the raw attention scores.
  188. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  189. # Normalize the attention scores to probabilities.
  190. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  191. # This is actually dropping out entire tokens to attend to, which might
  192. # seem a bit unusual, but is taken from the original Transformer paper.
  193. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  194. # Mask heads if we want to
  195. if attention_mask is not None:
  196. attn_weights = attn_weights * attention_mask
  197. attn_output = torch.matmul(attn_weights, value)
  198. attn_output = attn_output.transpose(1, 2).contiguous()
  199. return attn_output, attn_weights
  200. # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Yolos
  201. class YolosSelfAttention(nn.Module):
  202. def __init__(self, config: YolosConfig):
  203. super().__init__()
  204. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  205. raise ValueError(
  206. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  207. f"heads {config.num_attention_heads}."
  208. )
  209. self.config = config
  210. self.num_attention_heads = config.num_attention_heads
  211. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  212. self.all_head_size = self.num_attention_heads * self.attention_head_size
  213. self.dropout_prob = config.attention_probs_dropout_prob
  214. self.scaling = self.attention_head_size**-0.5
  215. self.is_causal = False
  216. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  217. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  218. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  219. def forward(
  220. self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
  221. ) -> tuple[torch.Tensor, torch.Tensor]:
  222. batch_size = hidden_states.shape[0]
  223. new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
  224. key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
  225. value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
  226. query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
  227. attention_interface: Callable = eager_attention_forward
  228. if self.config._attn_implementation != "eager":
  229. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  230. context_layer, attention_probs = attention_interface(
  231. self,
  232. query_layer,
  233. key_layer,
  234. value_layer,
  235. head_mask,
  236. is_causal=self.is_causal,
  237. scaling=self.scaling,
  238. dropout=0.0 if not self.training else self.dropout_prob,
  239. )
  240. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  241. context_layer = context_layer.reshape(new_context_layer_shape)
  242. return context_layer, attention_probs
  243. # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Yolos
  244. class YolosSelfOutput(nn.Module):
  245. """
  246. The residual connection is defined in YolosLayer instead of here (as is the case with other models), due to the
  247. layernorm applied before each block.
  248. """
  249. def __init__(self, config: YolosConfig):
  250. super().__init__()
  251. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  252. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  253. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  254. hidden_states = self.dense(hidden_states)
  255. hidden_states = self.dropout(hidden_states)
  256. return hidden_states
  257. # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Yolos
  258. class YolosAttention(nn.Module):
  259. def __init__(self, config: YolosConfig):
  260. super().__init__()
  261. self.attention = YolosSelfAttention(config)
  262. self.output = YolosSelfOutput(config)
  263. self.pruned_heads = set()
  264. def prune_heads(self, heads: set[int]):
  265. if len(heads) == 0:
  266. return
  267. heads, index = find_pruneable_heads_and_indices(
  268. heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
  269. )
  270. # Prune linear layers
  271. self.attention.query = prune_linear_layer(self.attention.query, index)
  272. self.attention.key = prune_linear_layer(self.attention.key, index)
  273. self.attention.value = prune_linear_layer(self.attention.value, index)
  274. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  275. # Update hyper params and store pruned heads
  276. self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
  277. self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
  278. self.pruned_heads = self.pruned_heads.union(heads)
  279. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  280. self_attn_output, _ = self.attention(hidden_states, head_mask)
  281. output = self.output(self_attn_output, hidden_states)
  282. return output
  283. # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Yolos
  284. class YolosIntermediate(nn.Module):
  285. def __init__(self, config: YolosConfig):
  286. super().__init__()
  287. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  288. if isinstance(config.hidden_act, str):
  289. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  290. else:
  291. self.intermediate_act_fn = config.hidden_act
  292. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  293. hidden_states = self.dense(hidden_states)
  294. hidden_states = self.intermediate_act_fn(hidden_states)
  295. return hidden_states
  296. # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->Yolos
  297. class YolosOutput(nn.Module):
  298. def __init__(self, config: YolosConfig):
  299. super().__init__()
  300. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  301. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  302. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  303. hidden_states = self.dense(hidden_states)
  304. hidden_states = self.dropout(hidden_states)
  305. hidden_states = hidden_states + input_tensor
  306. return hidden_states
  307. # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos,VIT->YOLOS
  308. class YolosLayer(GradientCheckpointingLayer):
  309. """This corresponds to the Block class in the timm implementation."""
  310. def __init__(self, config: YolosConfig):
  311. super().__init__()
  312. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  313. self.seq_len_dim = 1
  314. self.attention = YolosAttention(config)
  315. self.intermediate = YolosIntermediate(config)
  316. self.output = YolosOutput(config)
  317. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  318. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  319. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  320. hidden_states_norm = self.layernorm_before(hidden_states)
  321. attention_output = self.attention(hidden_states_norm, head_mask)
  322. # first residual connection
  323. hidden_states = attention_output + hidden_states
  324. # in Yolos, layernorm is also applied after self-attention
  325. layer_output = self.layernorm_after(hidden_states)
  326. layer_output = self.intermediate(layer_output)
  327. # second residual connection is done here
  328. layer_output = self.output(layer_output, hidden_states)
  329. return layer_output
  330. class YolosEncoder(nn.Module):
  331. def __init__(self, config: YolosConfig) -> None:
  332. super().__init__()
  333. self.config = config
  334. self.layer = nn.ModuleList([YolosLayer(config) for _ in range(config.num_hidden_layers)])
  335. self.gradient_checkpointing = False
  336. seq_length = (
  337. 1 + (config.image_size[0] * config.image_size[1] // config.patch_size**2) + config.num_detection_tokens
  338. )
  339. self.mid_position_embeddings = (
  340. nn.Parameter(
  341. torch.zeros(
  342. config.num_hidden_layers - 1,
  343. 1,
  344. seq_length,
  345. config.hidden_size,
  346. )
  347. )
  348. if config.use_mid_position_embeddings
  349. else None
  350. )
  351. self.interpolation = InterpolateMidPositionEmbeddings(config) if config.use_mid_position_embeddings else None
  352. def forward(
  353. self,
  354. hidden_states: torch.Tensor,
  355. height: int,
  356. width: int,
  357. head_mask: Optional[torch.Tensor] = None,
  358. ) -> BaseModelOutput:
  359. if self.config.use_mid_position_embeddings:
  360. interpolated_mid_position_embeddings = self.interpolation(self.mid_position_embeddings, (height, width))
  361. for i, layer_module in enumerate(self.layer):
  362. layer_head_mask = head_mask[i] if head_mask is not None else None
  363. hidden_states = layer_module(hidden_states, layer_head_mask)
  364. if self.config.use_mid_position_embeddings:
  365. if i < (self.config.num_hidden_layers - 1):
  366. hidden_states = hidden_states + interpolated_mid_position_embeddings[i]
  367. return BaseModelOutput(last_hidden_state=hidden_states)
  368. @auto_docstring
  369. class YolosPreTrainedModel(PreTrainedModel):
  370. config: YolosConfig
  371. base_model_prefix = "vit"
  372. main_input_name = "pixel_values"
  373. supports_gradient_checkpointing = True
  374. _no_split_modules = []
  375. _supports_sdpa = True
  376. _supports_flash_attn = True
  377. _supports_flex_attn = True
  378. _supports_attention_backend = True
  379. _can_record_outputs = {
  380. "hidden_states": YolosLayer,
  381. "attentions": YolosSelfAttention,
  382. }
  383. def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
  384. """Initialize the weights"""
  385. if isinstance(module, (nn.Linear, nn.Conv2d)):
  386. # Slightly different from the TF version which uses truncated_normal for initialization
  387. # cf https://github.com/pytorch/pytorch/pull/5617
  388. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  389. if module.bias is not None:
  390. module.bias.data.zero_()
  391. elif isinstance(module, nn.LayerNorm):
  392. module.bias.data.zero_()
  393. module.weight.data.fill_(1.0)
  394. @auto_docstring
  395. class YolosModel(YolosPreTrainedModel):
  396. def __init__(self, config: YolosConfig, add_pooling_layer: bool = True):
  397. r"""
  398. add_pooling_layer (bool, *optional*, defaults to `True`):
  399. Whether to add a pooling layer
  400. """
  401. super().__init__(config)
  402. self.config = config
  403. self.embeddings = YolosEmbeddings(config)
  404. self.encoder = YolosEncoder(config)
  405. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  406. self.pooler = YolosPooler(config) if add_pooling_layer else None
  407. # Initialize weights and apply final processing
  408. self.post_init()
  409. def get_input_embeddings(self) -> YolosPatchEmbeddings:
  410. return self.embeddings.patch_embeddings
  411. def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
  412. """
  413. Prunes heads of the model.
  414. Args:
  415. heads_to_prune (`dict`):
  416. See base class `PreTrainedModel`. The input dictionary must have the following format: {layer_num:
  417. list of heads to prune in this layer}
  418. """
  419. for layer, heads in heads_to_prune.items():
  420. self.encoder.layer[layer].attention.prune_heads(heads)
  421. @check_model_inputs(tie_last_hidden_states=False)
  422. @auto_docstring
  423. def forward(
  424. self,
  425. pixel_values: Optional[torch.Tensor] = None,
  426. head_mask: Optional[torch.Tensor] = None,
  427. **kwargs: Unpack[TransformersKwargs],
  428. ) -> BaseModelOutputWithPooling:
  429. if pixel_values is None:
  430. raise ValueError("You have to specify pixel_values")
  431. # Prepare head mask if needed
  432. # 1.0 in head_mask indicate we keep the head
  433. # attention_probs has shape bsz x n_heads x N x N
  434. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  435. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  436. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  437. embedding_output = self.embeddings(pixel_values)
  438. height, width = pixel_values.shape[-2:]
  439. encoder_outputs: BaseModelOutput = self.encoder(
  440. embedding_output, height=height, width=width, head_mask=head_mask
  441. )
  442. sequence_output = encoder_outputs.last_hidden_state
  443. sequence_output = self.layernorm(sequence_output)
  444. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  445. return BaseModelOutputWithPooling(last_hidden_state=sequence_output, pooler_output=pooled_output)
  446. class YolosPooler(nn.Module):
  447. def __init__(self, config: YolosConfig):
  448. super().__init__()
  449. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  450. self.activation = nn.Tanh()
  451. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  452. # We "pool" the model by simply taking the hidden state corresponding
  453. # to the first token.
  454. first_token_tensor = hidden_states[:, 0]
  455. pooled_output = self.dense(first_token_tensor)
  456. pooled_output = self.activation(pooled_output)
  457. return pooled_output
  458. # Copied from transformers.models.detr.modeling_detr.DetrMLPPredictionHead with Detr->Yolos
  459. class YolosMLPPredictionHead(nn.Module):
  460. """
  461. Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
  462. height and width of a bounding box w.r.t. an image.
  463. Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
  464. """
  465. def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
  466. super().__init__()
  467. self.num_layers = num_layers
  468. h = [hidden_dim] * (num_layers - 1)
  469. self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
  470. def forward(self, x):
  471. for i, layer in enumerate(self.layers):
  472. x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
  473. return x
  474. @auto_docstring(
  475. custom_intro="""
  476. YOLOS Model (consisting of a ViT encoder) with object detection heads on top, for tasks such as COCO detection.
  477. """
  478. )
  479. class YolosForObjectDetection(YolosPreTrainedModel):
  480. def __init__(self, config: YolosConfig):
  481. super().__init__(config)
  482. # YOLOS (ViT) encoder model
  483. self.vit = YolosModel(config, add_pooling_layer=False)
  484. # Object detection heads
  485. # We add one for the "no object" class
  486. self.class_labels_classifier = YolosMLPPredictionHead(
  487. input_dim=config.hidden_size, hidden_dim=config.hidden_size, output_dim=config.num_labels + 1, num_layers=3
  488. )
  489. self.bbox_predictor = YolosMLPPredictionHead(
  490. input_dim=config.hidden_size, hidden_dim=config.hidden_size, output_dim=4, num_layers=3
  491. )
  492. # Initialize weights and apply final processing
  493. self.post_init()
  494. # taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
  495. @torch.jit.unused
  496. def _set_aux_loss(self, outputs_class, outputs_coord):
  497. # this is a workaround to make torchscript happy, as torchscript
  498. # doesn't support dictionary with non-homogeneous values, such
  499. # as a dict having both a Tensor and a list.
  500. return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
  501. @can_return_tuple
  502. @auto_docstring
  503. def forward(
  504. self,
  505. pixel_values: torch.FloatTensor,
  506. labels: Optional[list[dict]] = None,
  507. **kwargs: Unpack[TransformersKwargs],
  508. ) -> YolosObjectDetectionOutput:
  509. r"""
  510. labels (`list[Dict]` of len `(batch_size,)`, *optional*):
  511. Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
  512. following 2 keys: `'class_labels'` and `'boxes'` (the class labels and bounding boxes of an image in the
  513. batch respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding
  514. boxes in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image,
  515. 4)`.
  516. Examples:
  517. ```python
  518. >>> from transformers import AutoImageProcessor, AutoModelForObjectDetection
  519. >>> import torch
  520. >>> from PIL import Image
  521. >>> import requests
  522. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  523. >>> image = Image.open(requests.get(url, stream=True).raw)
  524. >>> image_processor = AutoImageProcessor.from_pretrained("hustvl/yolos-tiny")
  525. >>> model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-tiny")
  526. >>> inputs = image_processor(images=image, return_tensors="pt")
  527. >>> outputs = model(**inputs)
  528. >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
  529. >>> target_sizes = torch.tensor([image.size[::-1]])
  530. >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[
  531. ... 0
  532. ... ]
  533. >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
  534. ... box = [round(i, 2) for i in box.tolist()]
  535. ... print(
  536. ... f"Detected {model.config.id2label[label.item()]} with confidence "
  537. ... f"{round(score.item(), 3)} at location {box}"
  538. ... )
  539. Detected remote with confidence 0.991 at location [46.48, 72.78, 178.98, 119.3]
  540. Detected remote with confidence 0.908 at location [336.48, 79.27, 368.23, 192.36]
  541. Detected cat with confidence 0.934 at location [337.18, 18.06, 638.14, 373.09]
  542. Detected cat with confidence 0.979 at location [10.93, 53.74, 313.41, 470.67]
  543. Detected remote with confidence 0.974 at location [41.63, 72.23, 178.09, 119.99]
  544. ```"""
  545. # First, sent images through YOLOS base model to obtain hidden states
  546. outputs: BaseModelOutputWithPooling = self.vit(pixel_values, **kwargs)
  547. sequence_output = outputs.last_hidden_state
  548. # Take the final hidden states of the detection tokens
  549. sequence_output = sequence_output[:, -self.config.num_detection_tokens :, :]
  550. # Class logits + predicted bounding boxes
  551. logits = self.class_labels_classifier(sequence_output)
  552. pred_boxes = self.bbox_predictor(sequence_output).sigmoid()
  553. loss, loss_dict, auxiliary_outputs = None, None, None
  554. if labels is not None:
  555. outputs_class, outputs_coord = None, None
  556. if self.config.auxiliary_loss:
  557. intermediate = outputs.hidden_states
  558. outputs_class = self.class_labels_classifier(intermediate)
  559. outputs_coord = self.bbox_predictor(intermediate).sigmoid()
  560. loss, loss_dict, auxiliary_outputs = self.loss_function(
  561. logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord
  562. )
  563. return YolosObjectDetectionOutput(
  564. loss=loss,
  565. loss_dict=loss_dict,
  566. logits=logits,
  567. pred_boxes=pred_boxes,
  568. auxiliary_outputs=auxiliary_outputs,
  569. last_hidden_state=outputs.last_hidden_state,
  570. hidden_states=outputs.hidden_states,
  571. attentions=outputs.attentions,
  572. )
  573. __all__ = ["YolosForObjectDetection", "YolosModel", "YolosPreTrainedModel"]