modeling_videomae.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906
  1. # coding=utf-8
  2. # Copyright 2022 Multimedia Computing Group, Nanjing University and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch VideoMAE (masked autoencoder) model."""
  16. import collections.abc
  17. from copy import deepcopy
  18. from dataclasses import dataclass
  19. from typing import Callable, Optional
  20. import numpy as np
  21. import torch
  22. from torch import nn
  23. from torch.nn import MSELoss
  24. from ...activations import ACT2FN
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
  27. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  28. from ...processing_utils import Unpack
  29. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  30. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
  31. from ...utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  32. from ...utils.generic import can_return_tuple, check_model_inputs
  33. from .configuration_videomae import VideoMAEConfig
  34. logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Class for VideoMAEDecoder's outputs, with potential hidden states and attentions.
    """
)
class VideoMAEDecoderOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, num_returned_patches, tubelet_size * patch_size ** 2 * num_channels)`):
        Pixel reconstruction logits, one row per patch token kept by the decoder (the last `return_token_num`
        tokens, or every token when `return_token_num` is not positive).
    """

    # All fields default to `None`; within this file the decoder only populates `logits`.
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Class for VideoMAEForPreTraining's outputs, with potential hidden states and attentions.
    """
)
class VideoMAEForPreTrainingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`):
        Pixel reconstruction loss.
    logits (`torch.FloatTensor` of shape `(batch_size, num_masked_patches, tubelet_size * patch_size ** 2 * num_channels)`):
        Pixel reconstruction logits.
    """

    # NOTE(review): the logits shape above is derived from the decoder head, which is called with the number of
    # masked patches as `return_token_num` -- confirm against the return statement of VideoMAEForPreTraining.forward.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
  66. # sin-cos position encoding
  67. # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
  68. def get_sinusoid_encoding_table(n_position, d_hid):
  69. """Sinusoid position encoding table"""
  70. # TODO: make it with torch instead of numpy
  71. def get_position_angle_vec(position):
  72. return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
  73. sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
  74. sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
  75. sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
  76. return torch.FloatTensor(sinusoid_table).unsqueeze(0)
  77. class VideoMAEEmbeddings(nn.Module):
  78. """
  79. Construct the patch and position embeddings.
  80. """
  81. def __init__(self, config):
  82. super().__init__()
  83. self.patch_embeddings = VideoMAEPatchEmbeddings(config)
  84. self.num_patches = self.patch_embeddings.num_patches
  85. # fixed sin-cos embedding
  86. self.position_embeddings = get_sinusoid_encoding_table(self.num_patches, config.hidden_size)
  87. self.config = config
  88. def forward(self, pixel_values, bool_masked_pos):
  89. # create patch embeddings
  90. embeddings = self.patch_embeddings(pixel_values)
  91. # add position embeddings
  92. embeddings = embeddings + self.position_embeddings.detach().type_as(embeddings).to(
  93. device=embeddings.device, copy=True
  94. )
  95. # only keep visible patches
  96. # ~bool_masked_pos means visible
  97. if bool_masked_pos is not None:
  98. batch_size, _, num_channels = embeddings.shape
  99. embeddings = embeddings[~bool_masked_pos]
  100. embeddings = embeddings.reshape(batch_size, -1, num_channels)
  101. return embeddings
  102. class VideoMAEPatchEmbeddings(nn.Module):
  103. """
  104. Video to Patch Embedding. This module turns a batch of videos of shape (batch_size, num_frames, num_channels,
  105. height, width) into a tensor of shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.
  106. The seq_len (the number of patches) equals (number of frames // tubelet_size) * (height // patch_size) * (width //
  107. patch_size).
  108. """
  109. def __init__(self, config):
  110. super().__init__()
  111. image_size = config.image_size
  112. patch_size = config.patch_size
  113. num_channels = config.num_channels
  114. hidden_size = config.hidden_size
  115. num_frames = config.num_frames
  116. tubelet_size = config.tubelet_size
  117. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  118. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  119. self.image_size = image_size
  120. self.patch_size = patch_size
  121. self.tubelet_size = int(tubelet_size)
  122. num_patches = (
  123. (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
  124. )
  125. self.num_channels = num_channels
  126. self.num_patches = num_patches
  127. self.projection = nn.Conv3d(
  128. in_channels=num_channels,
  129. out_channels=hidden_size,
  130. kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
  131. stride=(self.tubelet_size, patch_size[0], patch_size[1]),
  132. )
  133. def forward(self, pixel_values):
  134. batch_size, num_frames, num_channels, height, width = pixel_values.shape
  135. if num_channels != self.num_channels:
  136. raise ValueError(
  137. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  138. )
  139. if height != self.image_size[0] or width != self.image_size[1]:
  140. raise ValueError(
  141. f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
  142. )
  143. # permute to (batch_size, num_channels, num_frames, height, width)
  144. pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
  145. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  146. return embeddings
  147. # Copied from transformers.models.vit.modeling_vit.eager_attention_forward
  148. def eager_attention_forward(
  149. module: nn.Module,
  150. query: torch.Tensor,
  151. key: torch.Tensor,
  152. value: torch.Tensor,
  153. attention_mask: Optional[torch.Tensor],
  154. scaling: float,
  155. dropout: float = 0.0,
  156. **kwargs,
  157. ):
  158. # Take the dot product between "query" and "key" to get the raw attention scores.
  159. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  160. # Normalize the attention scores to probabilities.
  161. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  162. # This is actually dropping out entire tokens to attend to, which might
  163. # seem a bit unusual, but is taken from the original Transformer paper.
  164. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  165. # Mask heads if we want to
  166. if attention_mask is not None:
  167. attn_weights = attn_weights * attention_mask
  168. attn_output = torch.matmul(attn_weights, value)
  169. attn_output = attn_output.transpose(1, 2).contiguous()
  170. return attn_output, attn_weights
  171. class VideoMAESelfAttention(nn.Module):
  172. def __init__(self, config: VideoMAEConfig) -> None:
  173. super().__init__()
  174. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  175. raise ValueError(
  176. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  177. f"heads {config.num_attention_heads}."
  178. )
  179. self.config = config
  180. self.num_attention_heads = config.num_attention_heads
  181. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  182. self.all_head_size = self.num_attention_heads * self.attention_head_size
  183. self.dropout_prob = config.attention_probs_dropout_prob
  184. self.scaling = self.attention_head_size**-0.5
  185. self.is_causal = False
  186. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
  187. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
  188. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
  189. if config.qkv_bias:
  190. self.q_bias = nn.Parameter(torch.zeros(self.all_head_size))
  191. self.v_bias = nn.Parameter(torch.zeros(self.all_head_size))
  192. else:
  193. self.q_bias = None
  194. self.v_bias = None
  195. def forward(self, hidden_states, head_mask: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]:
  196. batch_size, seq_length, _ = hidden_states.shape
  197. k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None
  198. keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias)
  199. values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias)
  200. queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias)
  201. key_layer = keys.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  202. value_layer = values.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  203. query_layer = queries.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  204. attention_interface: Callable = eager_attention_forward
  205. if self.config._attn_implementation != "eager":
  206. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  207. context_layer, attention_probs = attention_interface(
  208. self,
  209. query_layer,
  210. key_layer,
  211. value_layer,
  212. head_mask,
  213. is_causal=self.is_causal,
  214. scaling=self.scaling,
  215. dropout=0.0 if not self.training else self.dropout_prob,
  216. )
  217. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  218. context_layer = context_layer.reshape(new_context_layer_shape)
  219. return context_layer, attention_probs
  220. # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VideoMAE
  221. class VideoMAESelfOutput(nn.Module):
  222. """
  223. The residual connection is defined in VideoMAELayer instead of here (as is the case with other models), due to the
  224. layernorm applied before each block.
  225. """
  226. def __init__(self, config: VideoMAEConfig):
  227. super().__init__()
  228. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  229. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  230. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  231. hidden_states = self.dense(hidden_states)
  232. hidden_states = self.dropout(hidden_states)
  233. return hidden_states
  234. # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->VideoMAE
  235. class VideoMAEAttention(nn.Module):
  236. def __init__(self, config: VideoMAEConfig):
  237. super().__init__()
  238. self.attention = VideoMAESelfAttention(config)
  239. self.output = VideoMAESelfOutput(config)
  240. self.pruned_heads = set()
  241. def prune_heads(self, heads: set[int]):
  242. if len(heads) == 0:
  243. return
  244. heads, index = find_pruneable_heads_and_indices(
  245. heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
  246. )
  247. # Prune linear layers
  248. self.attention.query = prune_linear_layer(self.attention.query, index)
  249. self.attention.key = prune_linear_layer(self.attention.key, index)
  250. self.attention.value = prune_linear_layer(self.attention.value, index)
  251. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  252. # Update hyper params and store pruned heads
  253. self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
  254. self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
  255. self.pruned_heads = self.pruned_heads.union(heads)
  256. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  257. self_attn_output, _ = self.attention(hidden_states, head_mask)
  258. output = self.output(self_attn_output, hidden_states)
  259. return output
  260. # Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->VideoMAE
  261. class VideoMAEIntermediate(nn.Module):
  262. def __init__(self, config: VideoMAEConfig):
  263. super().__init__()
  264. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  265. if isinstance(config.hidden_act, str):
  266. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  267. else:
  268. self.intermediate_act_fn = config.hidden_act
  269. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  270. hidden_states = self.dense(hidden_states)
  271. hidden_states = self.intermediate_act_fn(hidden_states)
  272. return hidden_states
  273. # Copied from transformers.models.vit.modeling_vit.ViTOutput ViT->VideoMAE
  274. class VideoMAEOutput(nn.Module):
  275. def __init__(self, config: VideoMAEConfig):
  276. super().__init__()
  277. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  278. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  279. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  280. hidden_states = self.dense(hidden_states)
  281. hidden_states = self.dropout(hidden_states)
  282. hidden_states = hidden_states + input_tensor
  283. return hidden_states
  284. # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE
  285. class VideoMAELayer(GradientCheckpointingLayer):
  286. """This corresponds to the Block class in the timm implementation."""
  287. def __init__(self, config: VideoMAEConfig):
  288. super().__init__()
  289. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  290. self.seq_len_dim = 1
  291. self.attention = VideoMAEAttention(config)
  292. self.intermediate = VideoMAEIntermediate(config)
  293. self.output = VideoMAEOutput(config)
  294. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  295. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  296. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  297. hidden_states_norm = self.layernorm_before(hidden_states)
  298. attention_output = self.attention(hidden_states_norm, head_mask)
  299. # first residual connection
  300. hidden_states = attention_output + hidden_states
  301. # in VideoMAE, layernorm is also applied after self-attention
  302. layer_output = self.layernorm_after(hidden_states)
  303. layer_output = self.intermediate(layer_output)
  304. # second residual connection is done here
  305. layer_output = self.output(layer_output, hidden_states)
  306. return layer_output
  307. # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->VideoMAE
  308. class VideoMAEEncoder(nn.Module):
  309. def __init__(self, config: VideoMAEConfig):
  310. super().__init__()
  311. self.config = config
  312. self.layer = nn.ModuleList([VideoMAELayer(config) for _ in range(config.num_hidden_layers)])
  313. self.gradient_checkpointing = False
  314. def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> BaseModelOutput:
  315. for i, layer_module in enumerate(self.layer):
  316. layer_head_mask = head_mask[i] if head_mask is not None else None
  317. hidden_states = layer_module(hidden_states, layer_head_mask)
  318. return BaseModelOutput(last_hidden_state=hidden_states)
  319. @auto_docstring
  320. class VideoMAEPreTrainedModel(PreTrainedModel):
  321. config: VideoMAEConfig
  322. base_model_prefix = "videomae"
  323. main_input_name = "pixel_values"
  324. supports_gradient_checkpointing = True
  325. _no_split_modules = ["VideoMAEEmbeddings", "VideoMAELayer"]
  326. _supports_sdpa = True
  327. _supports_flash_attn = True
  328. _supports_flex_attn = True
  329. _supports_attention_backend = True
  330. _can_record_outputs = {
  331. "hidden_states": VideoMAELayer,
  332. "attentions": VideoMAESelfAttention,
  333. }
  334. def _init_weights(self, module):
  335. """Initialize the weights"""
  336. if isinstance(module, (nn.Linear, nn.Conv3d)):
  337. # Slightly different from the TF version which uses truncated_normal for initialization
  338. # cf https://github.com/pytorch/pytorch/pull/5617
  339. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  340. if module.bias is not None:
  341. module.bias.data.zero_()
  342. elif isinstance(module, nn.LayerNorm):
  343. module.bias.data.zero_()
  344. module.weight.data.fill_(1.0)
  345. @auto_docstring
  346. class VideoMAEModel(VideoMAEPreTrainedModel):
  347. def __init__(self, config):
  348. super().__init__(config)
  349. self.config = config
  350. self.embeddings = VideoMAEEmbeddings(config)
  351. self.encoder = VideoMAEEncoder(config)
  352. if config.use_mean_pooling:
  353. self.layernorm = None
  354. else:
  355. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  356. # Initialize weights and apply final processing
  357. self.post_init()
  358. def get_input_embeddings(self):
  359. return self.embeddings.patch_embeddings
  360. def _prune_heads(self, heads_to_prune):
  361. """
  362. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  363. class PreTrainedModel
  364. """
  365. for layer, heads in heads_to_prune.items():
  366. self.encoder.layer[layer].attention.prune_heads(heads)
  367. @check_model_inputs(tie_last_hidden_states=False)
  368. @auto_docstring
  369. def forward(
  370. self,
  371. pixel_values: torch.FloatTensor,
  372. bool_masked_pos: Optional[torch.BoolTensor] = None,
  373. head_mask: Optional[torch.Tensor] = None,
  374. **kwargs: Unpack[TransformersKwargs],
  375. ) -> BaseModelOutput:
  376. r"""
  377. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  378. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Each video in the
  379. batch must have the same number of masked patches. If `None`, then all patches are considered. Sequence
  380. length is `(num_frames // tubelet_size) * (image_size // patch_size) ** 2`.
  381. Examples:
  382. ```python
  383. >>> import av
  384. >>> import numpy as np
  385. >>> from transformers import AutoImageProcessor, VideoMAEModel
  386. >>> from huggingface_hub import hf_hub_download
  387. >>> np.random.seed(0)
  388. >>> def read_video_pyav(container, indices):
  389. ... '''
  390. ... Decode the video with PyAV decoder.
  391. ... Args:
  392. ... container (`av.container.input.InputContainer`): PyAV container.
  393. ... indices (`list[int]`): List of frame indices to decode.
  394. ... Returns:
  395. ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
  396. ... '''
  397. ... frames = []
  398. ... container.seek(0)
  399. ... start_index = indices[0]
  400. ... end_index = indices[-1]
  401. ... for i, frame in enumerate(container.decode(video=0)):
  402. ... if i > end_index:
  403. ... break
  404. ... if i >= start_index and i in indices:
  405. ... frames.append(frame)
  406. ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
  407. >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
  408. ... '''
  409. ... Sample a given number of frame indices from the video.
  410. ... Args:
  411. ... clip_len (`int`): Total number of frames to sample.
  412. ... frame_sample_rate (`int`): Sample every n-th frame.
  413. ... seg_len (`int`): Maximum allowed index of sample's last frame.
  414. ... Returns:
  415. ... indices (`list[int]`): List of sampled frame indices
  416. ... '''
  417. ... converted_len = int(clip_len * frame_sample_rate)
  418. ... end_idx = np.random.randint(converted_len, seg_len)
  419. ... start_idx = end_idx - converted_len
  420. ... indices = np.linspace(start_idx, end_idx, num=clip_len)
  421. ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
  422. ... return indices
  423. >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
  424. >>> file_path = hf_hub_download(
  425. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  426. ... )
  427. >>> container = av.open(file_path)
  428. >>> # sample 16 frames
  429. >>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
  430. >>> video = read_video_pyav(container, indices)
  431. >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
  432. >>> model = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base")
  433. >>> # prepare video for the model
  434. >>> inputs = image_processor(list(video), return_tensors="pt")
  435. >>> # forward pass
  436. >>> outputs = model(**inputs)
  437. >>> last_hidden_states = outputs.last_hidden_state
  438. >>> list(last_hidden_states.shape)
  439. [1, 1568, 768]
  440. ```"""
  441. # Prepare head mask if needed
  442. # 1.0 in head_mask indicate we keep the head
  443. # attention_probs has shape bsz x n_heads x N x N
  444. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  445. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  446. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  447. embedding_output = self.embeddings(pixel_values, bool_masked_pos)
  448. encoder_outputs: BaseModelOutput = self.encoder(embedding_output, head_mask=head_mask)
  449. sequence_output = encoder_outputs.last_hidden_state
  450. if self.layernorm is not None:
  451. sequence_output = self.layernorm(sequence_output)
  452. return BaseModelOutput(last_hidden_state=sequence_output)
  453. class VideoMAEDecoder(nn.Module):
  454. def __init__(self, config: VideoMAEConfig):
  455. super().__init__()
  456. decoder_num_labels = config.num_channels * config.tubelet_size * config.patch_size**2
  457. decoder_config = deepcopy(config)
  458. decoder_config.hidden_size = config.decoder_hidden_size
  459. decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
  460. decoder_config.num_attention_heads = config.decoder_num_attention_heads
  461. decoder_config.intermediate_size = config.decoder_intermediate_size
  462. self.decoder_layers = nn.ModuleList(
  463. [VideoMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
  464. )
  465. self.norm = nn.LayerNorm(config.decoder_hidden_size)
  466. self.head = (
  467. nn.Linear(config.decoder_hidden_size, decoder_num_labels) if decoder_num_labels > 0 else nn.Identity()
  468. )
  469. self.gradient_checkpointing = False
  470. self.config = decoder_config
  471. def forward(self, hidden_states: torch.Tensor, return_token_num: int):
  472. # Apply transformer layers
  473. for layer_module in self.decoder_layers:
  474. hidden_states = layer_module(hidden_states, head_mask=None)
  475. if return_token_num > 0:
  476. hidden_states = hidden_states[:, -return_token_num:]
  477. # predictor projection
  478. hidden_states = self.norm(hidden_states)
  479. logits = self.head(hidden_states)
  480. return VideoMAEDecoderOutput(logits=logits)
  481. @auto_docstring(
  482. custom_intro="""
  483. The VideoMAE Model transformer with the decoder on top for self-supervised pre-training.
  484. """
  485. )
  486. class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
  487. def __init__(self, config):
  488. super().__init__(config)
  489. self.config = config
  490. self.videomae = VideoMAEModel(config)
  491. self.encoder_to_decoder = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=False)
  492. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
  493. self.position_embeddings = get_sinusoid_encoding_table(
  494. self.videomae.embeddings.num_patches, config.decoder_hidden_size
  495. )
  496. self.decoder = VideoMAEDecoder(config)
  497. # Initialize weights and apply final processing
  498. self.post_init()
  499. @can_return_tuple
  500. @auto_docstring
  501. def forward(
  502. self,
  503. pixel_values: torch.FloatTensor,
  504. bool_masked_pos: torch.BoolTensor,
  505. head_mask: Optional[torch.Tensor] = None,
  506. **kwargs: Unpack[TransformersKwargs],
  507. ) -> VideoMAEForPreTrainingOutput:
  508. r"""
  509. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
  510. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Each video in the
  511. batch must have the same number of masked patches. Sequence length is `(num_frames // tubelet_size) *
  512. (image_size // patch_size) ** 2`.
  513. Examples:
  514. ```python
  515. >>> from transformers import AutoImageProcessor, VideoMAEForPreTraining
  516. >>> import numpy as np
  517. >>> import torch
  518. >>> num_frames = 16
  519. >>> video = list(np.random.randint(0, 256, (num_frames, 3, 224, 224)))
  520. >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
  521. >>> model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base")
  522. >>> pixel_values = image_processor(video, return_tensors="pt").pixel_values
  523. >>> num_patches_per_frame = (model.config.image_size // model.config.patch_size) ** 2
  524. >>> seq_length = (num_frames // model.config.tubelet_size) * num_patches_per_frame
  525. >>> bool_masked_pos = torch.randint(0, 2, (1, seq_length)).bool()
  526. >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
  527. >>> loss = outputs.loss
  528. ```"""
  529. outputs: BaseModelOutput = self.videomae(
  530. pixel_values, bool_masked_pos=bool_masked_pos, head_mask=head_mask, **kwargs
  531. )
  532. sequence_output = outputs.last_hidden_state
  533. sequence_output = self.encoder_to_decoder(sequence_output)
  534. # [batch_size, num_visible_patches, decoder_hidden_size]
  535. batch_size, _, num_channels = sequence_output.shape
  536. # we don't unshuffle the correct visible token order, but shuffle the position embeddings accordingly.
  537. if bool_masked_pos is None:
  538. raise ValueError("One must provided a boolean mask ")
  539. expanded_position_embeddings = self.position_embeddings.expand(batch_size, -1, -1).type_as(pixel_values)
  540. expanded_position_embeddings = expanded_position_embeddings.detach().to(device=pixel_values.device, copy=True)
  541. pos_emb_visible = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels)
  542. pos_emb_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels)
  543. # [batch_size, num_patches, decoder_hidden_size]
  544. x_full = torch.cat([sequence_output + pos_emb_visible, self.mask_token + pos_emb_mask], dim=1)
  545. # [batch_size, num_masked_patches, num_channels * patch_size * patch_size]
  546. decoder_outputs: VideoMAEDecoderOutput = self.decoder(x_full, pos_emb_mask.shape[1])
  547. logits = decoder_outputs.logits
  548. loss = None
  549. with torch.no_grad():
  550. # calculate the labels to be predicted
  551. if self.config.num_channels != 3:
  552. # Can't unnormalize with default means/stds
  553. frames = pixel_values
  554. else:
  555. # first, unnormalize the frames
  556. device = pixel_values.device
  557. dtype = pixel_values.dtype
  558. mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
  559. std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
  560. frames = pixel_values * std + mean # in [0, 1]
  561. batch_size, time, num_channels, height, width = frames.shape
  562. tubelet_size, patch_size = self.config.tubelet_size, self.config.patch_size
  563. if self.config.norm_pix_loss:
  564. # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
  565. frames = frames.view(
  566. batch_size,
  567. time // tubelet_size,
  568. tubelet_size,
  569. num_channels,
  570. height // patch_size,
  571. patch_size,
  572. width // patch_size,
  573. patch_size,
  574. )
  575. # step 2: move dimensions to concatenate:
  576. frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
  577. # step 3: concatenate:
  578. frames = frames.view(
  579. batch_size,
  580. time // tubelet_size * height // patch_size * width // patch_size,
  581. tubelet_size * patch_size * patch_size,
  582. num_channels,
  583. )
  584. # step 4: normalize. The authors find that the mean is about 0.48 and standard deviation is about 0.08.
  585. frames_norm = (frames - frames.mean(dim=-2, keepdim=True)) / (
  586. frames.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
  587. )
  588. # step 5: reshape to (batch_size, T//ts * H//ps * W//ps, ts * ps * ps * C)
  589. videos_patch = frames_norm.view(
  590. batch_size,
  591. time // tubelet_size * height // patch_size * width // patch_size,
  592. tubelet_size * patch_size * patch_size * num_channels,
  593. )
  594. else:
  595. if self.config.num_channels != 3:
  596. raise ValueError(
  597. "Can't unnormalize non-RGB images. Consider setting config.norm_pix_loss to False."
  598. )
  599. # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
  600. frames = frames.view(
  601. batch_size,
  602. time // tubelet_size,
  603. tubelet_size,
  604. num_channels,
  605. height // patch_size,
  606. patch_size,
  607. width // patch_size,
  608. patch_size,
  609. )
  610. # step 2: move dimensions to concatenate: (batch_size, T//ts, H//ps, W//ps, ts, ps, ps, C)
  611. frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
  612. # step 3: concatenate
  613. videos_patch = frames.view(
  614. batch_size,
  615. time // tubelet_size * height // patch_size * width // patch_size,
  616. tubelet_size * patch_size * patch_size * num_channels,
  617. )
  618. batch_size, _, num_channels = videos_patch.shape
  619. labels = videos_patch[bool_masked_pos].reshape(batch_size, -1, num_channels)
  620. loss_fct = MSELoss()
  621. loss = loss_fct(logits, labels)
  622. return VideoMAEForPreTrainingOutput(
  623. loss=loss,
  624. logits=logits,
  625. hidden_states=outputs.hidden_states,
  626. attentions=outputs.attentions,
  627. )
  628. @auto_docstring(
  629. custom_intro="""
  630. VideoMAE Model transformer with a video classification head on top (a linear layer on top of the average pooled hidden
  631. states of all tokens) e.g. for ImageNet.
  632. """
  633. )
  634. class VideoMAEForVideoClassification(VideoMAEPreTrainedModel):
  635. def __init__(self, config):
  636. super().__init__(config)
  637. self.num_labels = config.num_labels
  638. self.videomae = VideoMAEModel(config)
  639. # Classifier head
  640. self.fc_norm = nn.LayerNorm(config.hidden_size) if config.use_mean_pooling else None
  641. self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  642. # Initialize weights and apply final processing
  643. self.post_init()
  644. @can_return_tuple
  645. @auto_docstring
  646. def forward(
  647. self,
  648. pixel_values: Optional[torch.Tensor] = None,
  649. head_mask: Optional[torch.Tensor] = None,
  650. labels: Optional[torch.Tensor] = None,
  651. **kwargs: Unpack[TransformersKwargs],
  652. ) -> ImageClassifierOutput:
  653. r"""
  654. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  655. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  656. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  657. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  658. Examples:
  659. ```python
  660. >>> import av
  661. >>> import torch
  662. >>> import numpy as np
  663. >>> from transformers import AutoImageProcessor, VideoMAEForVideoClassification
  664. >>> from huggingface_hub import hf_hub_download
  665. >>> np.random.seed(0)
  666. >>> def read_video_pyav(container, indices):
  667. ... '''
  668. ... Decode the video with PyAV decoder.
  669. ... Args:
  670. ... container (`av.container.input.InputContainer`): PyAV container.
  671. ... indices (`list[int]`): List of frame indices to decode.
  672. ... Returns:
  673. ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
  674. ... '''
  675. ... frames = []
  676. ... container.seek(0)
  677. ... start_index = indices[0]
  678. ... end_index = indices[-1]
  679. ... for i, frame in enumerate(container.decode(video=0)):
  680. ... if i > end_index:
  681. ... break
  682. ... if i >= start_index and i in indices:
  683. ... frames.append(frame)
  684. ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
  685. >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
  686. ... '''
  687. ... Sample a given number of frame indices from the video.
  688. ... Args:
  689. ... clip_len (`int`): Total number of frames to sample.
  690. ... frame_sample_rate (`int`): Sample every n-th frame.
  691. ... seg_len (`int`): Maximum allowed index of sample's last frame.
  692. ... Returns:
  693. ... indices (`list[int]`): List of sampled frame indices
  694. ... '''
  695. ... converted_len = int(clip_len * frame_sample_rate)
  696. ... end_idx = np.random.randint(converted_len, seg_len)
  697. ... start_idx = end_idx - converted_len
  698. ... indices = np.linspace(start_idx, end_idx, num=clip_len)
  699. ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
  700. ... return indices
  701. >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
  702. >>> file_path = hf_hub_download(
  703. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  704. ... )
  705. >>> container = av.open(file_path)
  706. >>> # sample 16 frames
  707. >>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
  708. >>> video = read_video_pyav(container, indices)
  709. >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
  710. >>> model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
  711. >>> inputs = image_processor(list(video), return_tensors="pt")
  712. >>> with torch.no_grad():
  713. ... outputs = model(**inputs)
  714. ... logits = outputs.logits
  715. >>> # model predicts one of the 400 Kinetics-400 classes
  716. >>> predicted_label = logits.argmax(-1).item()
  717. >>> print(model.config.id2label[predicted_label])
  718. eating spaghetti
  719. ```"""
  720. outputs: BaseModelOutput = self.videomae(pixel_values, head_mask=head_mask, **kwargs)
  721. sequence_output = outputs.last_hidden_state
  722. if self.fc_norm is not None:
  723. output = sequence_output.mean(1)
  724. output = self.fc_norm(output)
  725. else:
  726. output = sequence_output[:, 0]
  727. logits = self.classifier(output)
  728. loss = None
  729. if labels is not None:
  730. loss = self.loss_function(labels, logits, self.config, **kwargs)
  731. return ImageClassifierOutput(
  732. loss=loss,
  733. logits=logits,
  734. hidden_states=outputs.hidden_states,
  735. attentions=outputs.attentions,
  736. )
  737. __all__ = ["VideoMAEForPreTraining", "VideoMAEModel", "VideoMAEPreTrainedModel", "VideoMAEForVideoClassification"]