modeling_vivit.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689
  1. # coding=utf-8
  2. # Copyright 2023 Google AI and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch ViViT model."""
  16. from typing import Callable, Optional
  17. import torch
  18. from torch import nn
  19. from ...activations import ACT2FN
  20. from ...modeling_layers import GradientCheckpointingLayer
  21. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
  22. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  23. from ...processing_utils import Unpack
  24. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  25. from ...utils import TransformersKwargs, auto_docstring, logging, torch_int
  26. from ...utils.generic import can_return_tuple, check_model_inputs
  27. from .configuration_vivit import VivitConfig
  28. logger = logging.get_logger(__name__)
  29. class VivitTubeletEmbeddings(nn.Module):
  30. """
  31. Construct Vivit Tubelet embeddings.
  32. This module turns a batch of videos of shape (batch_size, num_frames, num_channels, height, width) into a tensor of
  33. shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.
  34. The seq_len (the number of patches) equals (number of frames // tubelet_size[0]) * (height // tubelet_size[1]) *
  35. (width // tubelet_size[2]).
  36. """
  37. def __init__(self, config: VivitConfig):
  38. super().__init__()
  39. self.num_frames = config.num_frames
  40. self.image_size = config.image_size
  41. self.patch_size = config.tubelet_size
  42. self.num_patches = (
  43. (self.image_size // self.patch_size[2])
  44. * (self.image_size // self.patch_size[1])
  45. * (self.num_frames // self.patch_size[0])
  46. )
  47. self.embed_dim = config.hidden_size
  48. self.projection = nn.Conv3d(
  49. config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size
  50. )
  51. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  52. batch_size, num_frames, num_channels, height, width = pixel_values.shape
  53. if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
  54. raise ValueError(
  55. f"Image image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
  56. )
  57. # permute to (batch_size, num_channels, num_frames, height, width)
  58. pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
  59. x = self.projection(pixel_values)
  60. # out_batch_size, out_num_channels, out_num_frames, out_height, out_width = x.shape
  61. # flattens time and space dimensions, transposes to (out_batch_size, flat_tokens, out_num_channels)
  62. x = x.flatten(2).transpose(1, 2)
  63. return x
  64. class VivitEmbeddings(nn.Module):
  65. """
  66. Vivit Embeddings.
  67. Creates embeddings from a video using VivitTubeletEmbeddings, adds CLS token and positional embeddings.
  68. """
  69. def __init__(self, config: VivitConfig):
  70. super().__init__()
  71. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  72. self.patch_embeddings = VivitTubeletEmbeddings(config)
  73. self.position_embeddings = nn.Parameter(
  74. torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
  75. )
  76. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  77. self.patch_size = config.tubelet_size[1:]
  78. self.config = config
  79. # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
  80. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  81. """
  82. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  83. images. This method is also adapted to support torch.jit tracing.
  84. Adapted from:
  85. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  86. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  87. """
  88. num_patches = embeddings.shape[1] - 1
  89. num_positions = self.position_embeddings.shape[1] - 1
  90. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  91. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  92. return self.position_embeddings
  93. class_pos_embed = self.position_embeddings[:, :1]
  94. patch_pos_embed = self.position_embeddings[:, 1:]
  95. dim = embeddings.shape[-1]
  96. new_height = height // self.patch_size[0]
  97. new_width = width // self.patch_size[1]
  98. sqrt_num_positions = torch_int(num_positions**0.5)
  99. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  100. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  101. patch_pos_embed = nn.functional.interpolate(
  102. patch_pos_embed,
  103. size=(new_height, new_width),
  104. mode="bicubic",
  105. align_corners=False,
  106. )
  107. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  108. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  109. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  110. batch_size, num_frames, num_channels, height, width = pixel_values.shape
  111. embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  112. cls_tokens = self.cls_token.tile([batch_size, 1, 1])
  113. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  114. # add positional encoding to each token
  115. if interpolate_pos_encoding:
  116. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  117. else:
  118. embeddings = embeddings + self.position_embeddings
  119. embeddings = self.dropout(embeddings)
  120. return embeddings
  121. # Copied from transformers.models.vit.modeling_vit.eager_attention_forward
  122. def eager_attention_forward(
  123. module: nn.Module,
  124. query: torch.Tensor,
  125. key: torch.Tensor,
  126. value: torch.Tensor,
  127. attention_mask: Optional[torch.Tensor],
  128. scaling: float,
  129. dropout: float = 0.0,
  130. **kwargs,
  131. ):
  132. # Take the dot product between "query" and "key" to get the raw attention scores.
  133. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  134. # Normalize the attention scores to probabilities.
  135. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  136. # This is actually dropping out entire tokens to attend to, which might
  137. # seem a bit unusual, but is taken from the original Transformer paper.
  138. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  139. # Mask heads if we want to
  140. if attention_mask is not None:
  141. attn_weights = attn_weights * attention_mask
  142. attn_output = torch.matmul(attn_weights, value)
  143. attn_output = attn_output.transpose(1, 2).contiguous()
  144. return attn_output, attn_weights
  145. # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Vivit
  146. class VivitSelfAttention(nn.Module):
  147. def __init__(self, config: VivitConfig):
  148. super().__init__()
  149. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  150. raise ValueError(
  151. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  152. f"heads {config.num_attention_heads}."
  153. )
  154. self.config = config
  155. self.num_attention_heads = config.num_attention_heads
  156. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  157. self.all_head_size = self.num_attention_heads * self.attention_head_size
  158. self.dropout_prob = config.attention_probs_dropout_prob
  159. self.scaling = self.attention_head_size**-0.5
  160. self.is_causal = False
  161. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  162. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  163. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  164. def forward(
  165. self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
  166. ) -> tuple[torch.Tensor, torch.Tensor]:
  167. batch_size = hidden_states.shape[0]
  168. new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
  169. key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
  170. value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
  171. query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
  172. attention_interface: Callable = eager_attention_forward
  173. if self.config._attn_implementation != "eager":
  174. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  175. context_layer, attention_probs = attention_interface(
  176. self,
  177. query_layer,
  178. key_layer,
  179. value_layer,
  180. head_mask,
  181. is_causal=self.is_causal,
  182. scaling=self.scaling,
  183. dropout=0.0 if not self.training else self.dropout_prob,
  184. )
  185. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  186. context_layer = context_layer.reshape(new_context_layer_shape)
  187. return context_layer, attention_probs
  188. # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit
  189. class VivitSelfOutput(nn.Module):
  190. """
  191. The residual connection is defined in VivitLayer instead of here (as is the case with other models), due to the
  192. layernorm applied before each block.
  193. """
  194. def __init__(self, config: VivitConfig):
  195. super().__init__()
  196. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  197. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  198. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  199. hidden_states = self.dense(hidden_states)
  200. hidden_states = self.dropout(hidden_states)
  201. return hidden_states
  202. # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Vivit
  203. class VivitAttention(nn.Module):
  204. def __init__(self, config: VivitConfig):
  205. super().__init__()
  206. self.attention = VivitSelfAttention(config)
  207. self.output = VivitSelfOutput(config)
  208. self.pruned_heads = set()
  209. def prune_heads(self, heads: set[int]):
  210. if len(heads) == 0:
  211. return
  212. heads, index = find_pruneable_heads_and_indices(
  213. heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
  214. )
  215. # Prune linear layers
  216. self.attention.query = prune_linear_layer(self.attention.query, index)
  217. self.attention.key = prune_linear_layer(self.attention.key, index)
  218. self.attention.value = prune_linear_layer(self.attention.value, index)
  219. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  220. # Update hyper params and store pruned heads
  221. self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
  222. self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
  223. self.pruned_heads = self.pruned_heads.union(heads)
  224. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  225. self_attn_output, _ = self.attention(hidden_states, head_mask)
  226. output = self.output(self_attn_output, hidden_states)
  227. return output
  228. class VivitIntermediate(nn.Module):
  229. def __init__(self, config: VivitConfig):
  230. super().__init__()
  231. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  232. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  233. if isinstance(config.hidden_act, str):
  234. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  235. else:
  236. self.intermediate_act_fn = config.hidden_act
  237. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  238. hidden_states = self.dense(hidden_states)
  239. hidden_states = self.intermediate_act_fn(hidden_states)
  240. hidden_states = self.dropout(hidden_states)
  241. return hidden_states
  242. class VivitOutput(nn.Module):
  243. def __init__(self, config: VivitConfig):
  244. super().__init__()
  245. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  246. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  247. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  248. hidden_states = self.dense(hidden_states)
  249. hidden_states = self.dropout(hidden_states)
  250. hidden_states = hidden_states + input_tensor
  251. return hidden_states
  252. class VivitLayer(GradientCheckpointingLayer):
  253. """This corresponds to the EncoderBlock class in the scenic/vivit implementation."""
  254. def __init__(self, config: VivitConfig):
  255. super().__init__()
  256. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  257. self.seq_len_dim = 1
  258. self.attention = VivitAttention(config)
  259. self.intermediate = VivitIntermediate(config)
  260. self.output = VivitOutput(config)
  261. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  262. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  263. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  264. hidden_states_norm = self.layernorm_before(hidden_states)
  265. attention_output = self.attention(hidden_states_norm, head_mask)
  266. # first residual connection
  267. hidden_states = attention_output + hidden_states
  268. # in Vivit, layernorm is also applied after self-attention
  269. layer_output = self.layernorm_after(hidden_states)
  270. layer_output = self.intermediate(layer_output)
  271. # second residual connection is done here
  272. layer_output = self.output(layer_output, hidden_states)
  273. return layer_output
  274. class VivitEncoder(nn.Module):
  275. def __init__(self, config: VivitConfig):
  276. super().__init__()
  277. self.config = config
  278. self.layer = nn.ModuleList([VivitLayer(config) for _ in range(config.num_hidden_layers)])
  279. self.gradient_checkpointing = False
  280. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> BaseModelOutput:
  281. for i, layer_module in enumerate(self.layer):
  282. layer_head_mask = head_mask[i] if head_mask is not None else None
  283. hidden_states = layer_module(hidden_states, layer_head_mask)
  284. return BaseModelOutput(last_hidden_state=hidden_states)
  285. class VivitPooler(nn.Module):
  286. def __init__(self, config: VivitConfig):
  287. super().__init__()
  288. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  289. self.activation = nn.Tanh()
  290. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  291. # We "pool" the model by simply taking the hidden state corresponding
  292. # to the first token.
  293. first_token_tensor = hidden_states[:, 0]
  294. pooled_output = self.dense(first_token_tensor)
  295. pooled_output = self.activation(pooled_output)
  296. return pooled_output
  297. @auto_docstring
  298. class VivitPreTrainedModel(PreTrainedModel):
  299. config: VivitConfig
  300. base_model_prefix = "vivit"
  301. main_input_name = "pixel_values"
  302. supports_gradient_checkpointing = True
  303. _no_split_modules = []
  304. _supports_sdpa = True
  305. _supports_flash_attn = True
  306. _supports_flex_attn = True
  307. _supports_attention_backend = True
  308. _can_record_outputs = {
  309. "hidden_states": VivitLayer,
  310. "attentions": VivitSelfAttention,
  311. }
  312. def _init_weights(self, module):
  313. """Initialize the weights"""
  314. if isinstance(module, (nn.Linear, nn.Conv3d)):
  315. # Slightly different from the TF version which uses truncated_normal for initialization
  316. # cf https://github.com/pytorch/pytorch/pull/5617
  317. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  318. if module.bias is not None:
  319. module.bias.data.zero_()
  320. elif isinstance(module, nn.Embedding):
  321. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  322. if module.padding_idx is not None:
  323. module.weight.data[module.padding_idx].zero_()
  324. elif isinstance(module, nn.LayerNorm):
  325. module.bias.data.zero_()
  326. module.weight.data.fill_(1.0)
  327. elif isinstance(module, VivitEmbeddings):
  328. module.cls_token.data.zero_()
  329. module.position_embeddings.data.zero_()
  330. @auto_docstring
  331. class VivitModel(VivitPreTrainedModel):
  332. def __init__(self, config: VivitConfig, add_pooling_layer: bool = True):
  333. r"""
  334. add_pooling_layer (bool, *optional*, defaults to `True`):
  335. Whether to add a pooling layer
  336. """
  337. super().__init__(config)
  338. self.config = config
  339. self.embeddings = VivitEmbeddings(config)
  340. self.encoder = VivitEncoder(config)
  341. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  342. self.pooler = VivitPooler(config) if add_pooling_layer else None
  343. # Initialize weights and apply final processing
  344. self.post_init()
  345. def get_input_embeddings(self):
  346. return self.embeddings.patch_embeddings
  347. def _prune_heads(self, heads_to_prune):
  348. """
  349. Prunes heads of the model.
  350. Args:
  351. heads_to_prune:
  352. dict of {layer_num: list of heads to prune in this layer}
  353. """
  354. for layer, heads in heads_to_prune.items():
  355. self.encoder.layer[layer].attention.prune_heads(heads)
  356. @check_model_inputs(tie_last_hidden_states=False)
  357. @auto_docstring
  358. def forward(
  359. self,
  360. pixel_values: Optional[torch.FloatTensor] = None,
  361. head_mask: Optional[torch.FloatTensor] = None,
  362. interpolate_pos_encoding: bool = False,
  363. **kwargs: Unpack[TransformersKwargs],
  364. ) -> BaseModelOutputWithPooling:
  365. r"""
  366. Examples:
  367. ```python
  368. >>> import av
  369. >>> import numpy as np
  370. >>> from transformers import VivitImageProcessor, VivitModel
  371. >>> from huggingface_hub import hf_hub_download
  372. >>> np.random.seed(0)
  373. >>> def read_video_pyav(container, indices):
  374. ... '''
  375. ... Decode the video with PyAV decoder.
  376. ... Args:
  377. ... container (`av.container.input.InputContainer`): PyAV container.
  378. ... indices (`list[int]`): List of frame indices to decode.
  379. ... Returns:
  380. ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
  381. ... '''
  382. ... frames = []
  383. ... container.seek(0)
  384. ... start_index = indices[0]
  385. ... end_index = indices[-1]
  386. ... for i, frame in enumerate(container.decode(video=0)):
  387. ... if i > end_index:
  388. ... break
  389. ... if i >= start_index and i in indices:
  390. ... frames.append(frame)
  391. ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
  392. >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
  393. ... '''
  394. ... Sample a given number of frame indices from the video.
  395. ... Args:
  396. ... clip_len (`int`): Total number of frames to sample.
  397. ... frame_sample_rate (`int`): Sample every n-th frame.
  398. ... seg_len (`int`): Maximum allowed index of sample's last frame.
  399. ... Returns:
  400. ... indices (`list[int]`): List of sampled frame indices
  401. ... '''
  402. ... converted_len = int(clip_len * frame_sample_rate)
  403. ... end_idx = np.random.randint(converted_len, seg_len)
  404. ... start_idx = end_idx - converted_len
  405. ... indices = np.linspace(start_idx, end_idx, num=clip_len)
  406. ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
  407. ... return indices
  408. >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
  409. >>> file_path = hf_hub_download(
  410. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  411. ... )
  412. >>> container = av.open(file_path)
  413. >>> # sample 32 frames
  414. >>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
  415. >>> video = read_video_pyav(container=container, indices=indices)
  416. >>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
  417. >>> model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400")
  418. >>> # prepare video for the model
  419. >>> inputs = image_processor(list(video), return_tensors="pt")
  420. >>> # forward pass
  421. >>> outputs = model(**inputs)
  422. >>> last_hidden_states = outputs.last_hidden_state
  423. >>> list(last_hidden_states.shape)
  424. [1, 3137, 768]
  425. ```"""
  426. if pixel_values is None:
  427. raise ValueError("You have to specify pixel_values")
  428. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  429. embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  430. encoder_outputs: BaseModelOutput = self.encoder(embedding_output, head_mask=head_mask)
  431. sequence_output = encoder_outputs.last_hidden_state
  432. sequence_output = self.layernorm(sequence_output)
  433. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  434. return BaseModelOutputWithPooling(last_hidden_state=sequence_output, pooler_output=pooled_output)
  435. @auto_docstring(
  436. custom_intro="""
  437. ViViT Transformer model with a video classification head on top (a linear layer on top of the final hidden state of the
  438. [CLS] token) e.g. for Kinetics-400.
  439. <Tip>
  440. Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
  441. setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
  442. position embeddings to the higher resolution.
  443. </Tip>
  444. """
  445. )
  446. class VivitForVideoClassification(VivitPreTrainedModel):
  447. def __init__(self, config: VivitConfig):
  448. super().__init__(config)
  449. self.num_labels = config.num_labels
  450. self.vivit = VivitModel(config, add_pooling_layer=False)
  451. # Classifier head
  452. self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  453. # Initialize weights and apply final processing
  454. self.post_init()
  455. @can_return_tuple
  456. @auto_docstring
  457. def forward(
  458. self,
  459. pixel_values: Optional[torch.FloatTensor] = None,
  460. head_mask: Optional[torch.FloatTensor] = None,
  461. labels: Optional[torch.LongTensor] = None,
  462. interpolate_pos_encoding: bool = False,
  463. **kwargs: Unpack[TransformersKwargs],
  464. ) -> ImageClassifierOutput:
  465. r"""
  466. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  467. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  468. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  469. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  470. Examples:
  471. ```python
  472. >>> import av
  473. >>> import numpy as np
  474. >>> import torch
  475. >>> from transformers import VivitImageProcessor, VivitForVideoClassification
  476. >>> from huggingface_hub import hf_hub_download
  477. >>> np.random.seed(0)
  478. >>> def read_video_pyav(container, indices):
  479. ... '''
  480. ... Decode the video with PyAV decoder.
  481. ... Args:
  482. ... container (`av.container.input.InputContainer`): PyAV container.
  483. ... indices (`list[int]`): List of frame indices to decode.
  484. ... Returns:
  485. ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
  486. ... '''
  487. ... frames = []
  488. ... container.seek(0)
  489. ... start_index = indices[0]
  490. ... end_index = indices[-1]
  491. ... for i, frame in enumerate(container.decode(video=0)):
  492. ... if i > end_index:
  493. ... break
  494. ... if i >= start_index and i in indices:
  495. ... frames.append(frame)
  496. ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
  497. >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
  498. ... '''
  499. ... Sample a given number of frame indices from the video.
  500. ... Args:
  501. ... clip_len (`int`): Total number of frames to sample.
  502. ... frame_sample_rate (`int`): Sample every n-th frame.
  503. ... seg_len (`int`): Maximum allowed index of sample's last frame.
  504. ... Returns:
  505. ... indices (`list[int]`): List of sampled frame indices
  506. ... '''
  507. ... converted_len = int(clip_len * frame_sample_rate)
  508. ... end_idx = np.random.randint(converted_len, seg_len)
  509. ... start_idx = end_idx - converted_len
  510. ... indices = np.linspace(start_idx, end_idx, num=clip_len)
  511. ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
  512. ... return indices
  513. >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
  514. >>> file_path = hf_hub_download(
  515. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  516. ... )
  517. >>> container = av.open(file_path)
  518. >>> # sample 32 frames
  519. >>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=4, seg_len=container.streams.video[0].frames)
  520. >>> video = read_video_pyav(container=container, indices=indices)
  521. >>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
  522. >>> model = VivitForVideoClassification.from_pretrained("google/vivit-b-16x2-kinetics400")
  523. >>> inputs = image_processor(list(video), return_tensors="pt")
  524. >>> with torch.no_grad():
  525. ... outputs = model(**inputs)
  526. ... logits = outputs.logits
  527. >>> # model predicts one of the 400 Kinetics-400 classes
  528. >>> predicted_label = logits.argmax(-1).item()
  529. >>> print(model.config.id2label[predicted_label])
  530. LABEL_116
  531. ```"""
  532. outputs: BaseModelOutput = self.vivit(
  533. pixel_values, head_mask=head_mask, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
  534. )
  535. sequence_output = outputs.last_hidden_state
  536. logits = self.classifier(sequence_output[:, 0, :])
  537. loss = None
  538. if labels is not None:
  539. loss = self.loss_function(labels, logits, self.config, **kwargs)
  540. return ImageClassifierOutput(
  541. loss=loss,
  542. logits=logits,
  543. hidden_states=outputs.hidden_states,
  544. attentions=outputs.attentions,
  545. )
  546. __all__ = ["VivitModel", "VivitPreTrainedModel", "VivitForVideoClassification"]