modeling_vjepa2.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. from dataclasses import dataclass
  16. from typing import Callable, Optional, Union
  17. import torch
  18. from torch import nn
  19. from ...activations import ACT2FN
  20. from ...modeling_layers import GradientCheckpointingLayer
  21. from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
  22. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  23. from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
  24. from .configuration_vjepa2 import VJEPA2Config
# Module-level logger for this file, obtained from the library's logging helper (`...utils.logging`).
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    VJEPA Predictor outputs that also contains the masked encoder outputs
    """
)
class VJEPA2WithMaskedInputPredictorOutput(ModelOutput):
    r"""
    masked_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs):
        The masked hidden state of the model.
    target_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `target_mask` is provided which is applied on VJEPA2Encoder outputs):
        The target hidden state of the model.
    """

    # NOTE: declaration order matters for a ModelOutput — it presumably fixes the order of entries
    # returned by `to_tuple()` / integer indexing (verify against ModelOutput). Keep
    # `target_hidden_state` last so existing positional consumers are not broken.
    last_hidden_state: torch.FloatTensor
    masked_hidden_state: Optional[torch.FloatTensor] = None
    # Per-layer hidden states / attention maps, when requested.
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    target_hidden_state: Optional[torch.FloatTensor] = None
  44. @dataclass
  45. @auto_docstring(
  46. custom_intro="""
  47. VJEPA outputs that also contains the masked encoder outputs
  48. Optionally contains the predictor outputs
  49. """
  50. )
  51. class VJEPA2WithMaskedInputModelOutput(ModelOutput):
  52. r"""
  53. masked_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs):
  54. The masked hidden state of the model.
  55. predictor_output (`VJEPA2WithMaskedInputPredictorOutput`, *optional*):
  56. The output from the Predictor module.
  57. """
  58. last_hidden_state: torch.FloatTensor
  59. masked_hidden_state: Optional[torch.FloatTensor] = None
  60. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  61. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  62. predictor_output: Optional[VJEPA2WithMaskedInputPredictorOutput] = None
  63. def to_tuple(self):
  64. output = list(super().to_tuple())
  65. if isinstance(output[-1], VJEPA2WithMaskedInputPredictorOutput):
  66. output[-1] = output[-1].to_tuple()
  67. return tuple(output)
  68. class VJEPA2PatchEmbeddings3D(nn.Module):
  69. """
  70. Image to Patch Embedding
  71. """
  72. def __init__(
  73. self,
  74. config: VJEPA2Config,
  75. hidden_size: int = 1024,
  76. ):
  77. super().__init__()
  78. self.patch_size = config.patch_size
  79. self.tubelet_size = config.tubelet_size
  80. self.hidden_size = hidden_size
  81. self.proj = nn.Conv3d(
  82. in_channels=config.in_chans,
  83. out_channels=hidden_size,
  84. kernel_size=(config.tubelet_size, config.patch_size, config.patch_size),
  85. stride=(config.tubelet_size, config.patch_size, config.patch_size),
  86. )
  87. @staticmethod
  88. def num_patches(config):
  89. return (
  90. (config.frames_per_clip // config.tubelet_size)
  91. * (config.crop_size // config.patch_size)
  92. * (config.crop_size // config.patch_size)
  93. )
  94. def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
  95. x = self.proj(pixel_values_videos).flatten(2).transpose(1, 2)
  96. return x
  97. class VJEPA2Embeddings(nn.Module):
  98. """
  99. Construct mask token, position and patch embeddings.
  100. """
  101. def __init__(self, config: VJEPA2Config, hidden_size: int = 1024):
  102. super().__init__()
  103. self.config = config
  104. self.hidden_size = hidden_size
  105. self.patch_embeddings = VJEPA2PatchEmbeddings3D(config, hidden_size=hidden_size)
  106. self.num_patches = self.patch_embeddings.num_patches
  107. self.patch_size = config.patch_size
  108. def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
  109. num_frames = pixel_values_videos.shape[1]
  110. # Swap `frames` and `channels` dims, the result is:
  111. # (batch_size, channels, num_frames, height, width)
  112. pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4)
  113. # For some cases, if the input vision (image/video) consists of num_frames < tubelet_size,
  114. # then embedding lookup fails. In these cases, we duplicate the frames.
  115. if num_frames < self.config.tubelet_size:
  116. pixel_values_videos = pixel_values_videos.repeat(1, 1, self.config.tubelet_size, 1, 1)
  117. target_dtype = self.patch_embeddings.proj.weight.dtype
  118. pixel_values_videos = pixel_values_videos.to(dtype=target_dtype)
  119. embeddings = self.patch_embeddings(pixel_values_videos)
  120. return embeddings
  121. # Adapted from transformers.models.vit.modeling_vit.eager_attention_forward
  122. def eager_attention_forward(
  123. module: nn.Module,
  124. query: torch.Tensor,
  125. key: torch.Tensor,
  126. value: torch.Tensor,
  127. attention_mask: Optional[torch.Tensor],
  128. scaling: float,
  129. dropout: float = 0.0,
  130. **kwargs,
  131. ):
  132. # Take the dot product between "query" and "key" to get the raw attention scores.
  133. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  134. # Normalize the attention scores to probabilities.
  135. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  136. # This is actually dropping out entire tokens to attend to, which might
  137. # seem a bit unusual, but is taken from the original Transformer paper.
  138. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  139. # Mask heads if we want to
  140. if attention_mask is not None:
  141. attn_weights = attn_weights * attention_mask
  142. attn_output = torch.matmul(attn_weights, value)
  143. attn_output = attn_output.transpose(1, 2).contiguous()
  144. return attn_output, attn_weights
  145. def rotate_queries_or_keys(x, pos):
  146. B, num_heads, N, D = x.size()
  147. # similar to inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
  148. # they are computing this every time. instead HF style is to compute the inv_freq once and store it
  149. # -- compute angle for each position
  150. omega = torch.arange(D // 2, dtype=x.dtype, device=x.device)
  151. omega /= D / 2.0
  152. omega = 1.0 / 10000**omega # (D/2,)
  153. freq = pos.unsqueeze(-1) * omega # (..., N, D/2), outer product
  154. # -- build rotation matrix and apply
  155. emb_sin = freq.sin() # (..., N, D/2)
  156. emb_cos = freq.cos() # (..., N, D/2)
  157. emb_sin = emb_sin.squeeze(-1).repeat(1, 1, 1, 2)
  158. emb_cos = emb_cos.squeeze(-1).repeat(1, 1, 1, 2)
  159. # --
  160. y = x.unflatten(-1, (-1, 2))
  161. y1, y2 = y.unbind(dim=-1)
  162. y = torch.stack((-y2, y1), dim=-1)
  163. y = y.flatten(-2)
  164. return (x * emb_cos) + (y * emb_sin)
  165. class VJEPA2RopeAttention(nn.Module):
  166. def __init__(
  167. self,
  168. config: VJEPA2Config,
  169. hidden_size: int = 1024,
  170. num_attention_heads: int = 16,
  171. ):
  172. super().__init__()
  173. self.config = config
  174. self.hidden_size = hidden_size
  175. self.num_attention_heads = num_attention_heads
  176. if hidden_size % num_attention_heads != 0:
  177. raise ValueError(
  178. f"The hidden size {(hidden_size,)} is not a multiple of the number of attention "
  179. f"heads {num_attention_heads}."
  180. )
  181. self.attention_head_size = int(hidden_size / num_attention_heads)
  182. self.all_head_size = self.num_attention_heads * self.attention_head_size
  183. self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
  184. self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
  185. self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
  186. self.proj = nn.Linear(hidden_size, hidden_size)
  187. self.dropout_prob = config.attention_probs_dropout_prob
  188. self.dropout = nn.Dropout(self.dropout_prob)
  189. self.grid_size = self.config.crop_size // self.config.patch_size
  190. self.grid_depth = self.config.frames_per_clip // self.config.tubelet_size
  191. self.d_dim = int(2 * ((self.attention_head_size // 3) // 2))
  192. self.h_dim = int(2 * ((self.attention_head_size // 3) // 2))
  193. self.w_dim = int(2 * ((self.attention_head_size // 3) // 2))
  194. self.scaling = self.attention_head_size**-0.5
  195. self.is_causal = False
  196. def _get_frame_pos(self, ids):
  197. tokens_per_frame = int(self.grid_size * self.grid_size)
  198. return ids // tokens_per_frame
  199. def _get_height_pos(self, ids):
  200. # Remove frame component from ids
  201. tokens_per_frame = int(self.grid_size * self.grid_size)
  202. frame_ids = self._get_frame_pos(ids)
  203. ids = ids - tokens_per_frame * frame_ids
  204. # --
  205. tokens_per_row = self.grid_size
  206. return ids // tokens_per_row
  207. def get_position_ids(self, x, masks=None):
  208. device = x.device
  209. token_size = x.size(1)
  210. # Note: when masks is none, we use a 1d id instead of Bxnum_attention_heads mask,
  211. # as 1d vector is broadcasted to the correct shapes.
  212. if masks is not None:
  213. ids = masks.unsqueeze(1).repeat(1, self.num_attention_heads, 1)
  214. else:
  215. ids = torch.arange(token_size, device=device)
  216. # change to allow for extrapolation
  217. tokens_per_frame = int(self.grid_size * self.grid_size)
  218. frame_ids = self._get_frame_pos(ids)
  219. # --
  220. tokens_per_row = self.grid_size
  221. height_ids = self._get_height_pos(ids)
  222. # --
  223. # Remove frame component from ids (1st term) and height component (2nd term)
  224. width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
  225. return frame_ids, height_ids, width_ids
  226. def apply_rotary_embeddings(self, qk, pos_ids):
  227. d_mask, h_mask, w_mask = pos_ids
  228. s = 0
  229. qkd = rotate_queries_or_keys(qk[..., s : s + self.d_dim], pos=d_mask)
  230. s += self.d_dim
  231. qkh = rotate_queries_or_keys(qk[..., s : s + self.h_dim], pos=h_mask)
  232. s += self.h_dim
  233. qkw = rotate_queries_or_keys(qk[..., s : s + self.w_dim], pos=w_mask)
  234. s += self.w_dim
  235. # Combine rotated dimension
  236. if s < self.attention_head_size:
  237. qkr = qk[..., s:]
  238. qk = torch.cat([qkd, qkh, qkw, qkr], dim=-1)
  239. else:
  240. qk = torch.cat([qkd, qkh, qkw], dim=-1)
  241. return qk
  242. def forward(
  243. self,
  244. hidden_states,
  245. position_mask: Optional[torch.Tensor] = None,
  246. output_attentions: bool = False,
  247. head_mask: Optional[torch.Tensor] = None,
  248. ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
  249. batch_size, seq_length, _ = hidden_states.shape
  250. query_layer = (
  251. self.query(hidden_states)
  252. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  253. .transpose(1, 2)
  254. )
  255. key_layer = (
  256. self.key(hidden_states)
  257. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  258. .transpose(1, 2)
  259. )
  260. value_layer = (
  261. self.value(hidden_states)
  262. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  263. .transpose(1, 2)
  264. )
  265. pos_ids = self.get_position_ids(hidden_states, masks=position_mask)
  266. key_layer = self.apply_rotary_embeddings(key_layer, pos_ids)
  267. query_layer = self.apply_rotary_embeddings(query_layer, pos_ids)
  268. attention_interface: Callable = eager_attention_forward
  269. if self.config._attn_implementation != "eager":
  270. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  271. context_layer, attention_probs = attention_interface(
  272. self,
  273. query_layer,
  274. key_layer,
  275. value_layer,
  276. head_mask,
  277. is_causal=self.is_causal,
  278. scaling=self.scaling,
  279. dropout=0.0 if not self.training else self.dropout_prob,
  280. )
  281. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  282. context_layer = self.proj(context_layer.reshape(new_context_layer_shape))
  283. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  284. return outputs
  285. # Adapted from transformers.models.beit.modeling_dinov2.drop_path
  286. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  287. """
  288. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  289. Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
  290. however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  291. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
  292. layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
  293. argument.
  294. """
  295. if drop_prob == 0.0 or not training:
  296. return input
  297. keep_prob = 1 - drop_prob
  298. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  299. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  300. random_tensor.floor_() # binarize
  301. output = input.div(keep_prob) * random_tensor
  302. return output
  303. # Adapted from transformers.models.beit.modeling_beit.BeitDropPath
  304. class VJEPA2DropPath(nn.Module):
  305. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  306. def __init__(self, drop_prob: Optional[float] = None):
  307. super().__init__()
  308. self.drop_prob = drop_prob
  309. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  310. return drop_path(hidden_states, self.drop_prob, self.training)
  311. def extra_repr(self) -> str:
  312. return f"p={self.drop_prob}"
  313. class VJEPA2MLP(nn.Module):
  314. def __init__(self, config: VJEPA2Config, hidden_size: int = 1024, mlp_ratio: float = 4.0):
  315. super().__init__()
  316. in_features = out_features = hidden_size
  317. hidden_features = int(hidden_size * mlp_ratio)
  318. self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
  319. self.activation = ACT2FN[config.hidden_act]
  320. self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
  321. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  322. hidden_state = self.fc1(hidden_state)
  323. hidden_state = self.activation(hidden_state)
  324. hidden_state = self.fc2(hidden_state)
  325. return hidden_state
  326. class VJEPA2Layer(GradientCheckpointingLayer):
  327. """This corresponds to the Block class in the original implementation."""
  328. def __init__(
  329. self,
  330. config: VJEPA2Config,
  331. drop_path_rate: float = 0.0,
  332. hidden_size: int = 1024,
  333. num_attention_heads: int = 16,
  334. mlp_ratio: float = 4.0,
  335. ):
  336. super().__init__()
  337. self.config = config
  338. self.hidden_size = hidden_size
  339. self.num_attention_heads = num_attention_heads
  340. self.mlp_ratio = mlp_ratio
  341. self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  342. self.attention = VJEPA2RopeAttention(config, hidden_size, num_attention_heads)
  343. self.drop_path = VJEPA2DropPath(drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
  344. self.norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  345. self.mlp = VJEPA2MLP(config, hidden_size=hidden_size, mlp_ratio=mlp_ratio)
  346. def forward(
  347. self,
  348. hidden_states: torch.Tensor,
  349. position_mask: Optional[torch.Tensor] = None,
  350. head_mask: Optional[torch.Tensor] = None,
  351. output_attentions: bool = False,
  352. ) -> tuple[torch.Tensor, ...]:
  353. # Self-Attention
  354. residual = hidden_states
  355. hidden_states = self.norm1(hidden_states)
  356. self_attention_outputs = self.attention(
  357. hidden_states,
  358. position_mask=position_mask, # position mask for context/target selection
  359. head_mask=head_mask, # head mask is applied at F.scaled_dot_product_attention
  360. output_attentions=output_attentions,
  361. )
  362. attention_output = self_attention_outputs[0]
  363. hidden_states = self.drop_path(attention_output) + residual
  364. # MLP
  365. residual = hidden_states
  366. hidden_states = self.norm2(hidden_states)
  367. hidden_states = self.mlp(hidden_states)
  368. hidden_states = self.drop_path(hidden_states) + residual
  369. # Add self attentions if we output attention weights
  370. outputs = self_attention_outputs[1:]
  371. outputs = (hidden_states,) + outputs
  372. return outputs
  373. class VJEPA2Encoder(nn.Module):
  374. def __init__(self, config: VJEPA2Config):
  375. super().__init__()
  376. self.config = config
  377. self.embeddings = VJEPA2Embeddings(config, hidden_size=config.hidden_size)
  378. drop_path_rates = [
  379. (config.drop_path_rate * i / (config.num_hidden_layers - 1) if config.num_hidden_layers > 1 else 0.0)
  380. for i in range(config.num_hidden_layers)
  381. ]
  382. self.layer = nn.ModuleList(
  383. [
  384. VJEPA2Layer(
  385. config,
  386. drop_path_rate=drop_path_rates[i],
  387. hidden_size=config.hidden_size,
  388. num_attention_heads=config.num_attention_heads,
  389. mlp_ratio=config.mlp_ratio,
  390. )
  391. for i in range(config.num_hidden_layers)
  392. ]
  393. )
  394. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  395. self.gradient_checkpointing = False
  396. @can_return_tuple
  397. def forward(
  398. self,
  399. pixel_values_videos: Optional[torch.Tensor] = None,
  400. head_mask: Optional[torch.Tensor] = None,
  401. output_attentions: bool = False,
  402. output_hidden_states: bool = False,
  403. **kwargs,
  404. ) -> BaseModelOutput:
  405. all_hidden_states = () if output_hidden_states else None
  406. all_self_attentions = () if output_attentions else None
  407. hidden_states = self.embeddings(pixel_values_videos)
  408. for i, layer_module in enumerate(self.layer):
  409. if output_hidden_states:
  410. all_hidden_states = all_hidden_states + (hidden_states,)
  411. layer_head_mask = head_mask[i] if head_mask is not None else None
  412. layer_outputs = layer_module(hidden_states, None, layer_head_mask, output_attentions)
  413. hidden_states = layer_outputs[0]
  414. if output_attentions:
  415. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  416. hidden_states = self.layernorm(hidden_states)
  417. if output_hidden_states:
  418. all_hidden_states = all_hidden_states + (hidden_states,)
  419. return BaseModelOutput(
  420. last_hidden_state=hidden_states,
  421. hidden_states=all_hidden_states,
  422. attentions=all_self_attentions,
  423. )
  424. def apply_masks(tensor: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor:
  425. """
  426. Args:
  427. tensor (`torch.Tensor`):
  428. Tensor of shape [batch_size, num_patches, feature_dim]
  429. masks (`List[torch.Tensor]`):
  430. List of tensors of shape [batch_size, num_patches] containing indices of patches to keep
  431. """
  432. all_masked_tensors = []
  433. for mask in masks:
  434. mask = mask.to(tensor.device)
  435. mask_keep = mask.unsqueeze(-1).repeat(1, 1, tensor.size(-1))
  436. all_masked_tensors += [torch.gather(tensor, dim=1, index=mask_keep)]
  437. return torch.cat(all_masked_tensors, dim=0)
  438. class VJEPA2PredictorEmbeddings(nn.Module):
  439. """
  440. Construct mask token, position and patch embeddings.
  441. """
  442. def __init__(self, config: VJEPA2Config):
  443. super().__init__()
  444. self.config = config
  445. self.predictor_embeddings = nn.Linear(config.hidden_size, config.pred_hidden_size)
  446. self.num_mask_tokens = 0
  447. self.zero_init_mask_tokens = config.pred_zero_init_mask_tokens
  448. self.num_mask_tokens = config.pred_num_mask_tokens
  449. self.mask_tokens = nn.Parameter(torch.zeros(self.num_mask_tokens, 1, 1, config.pred_hidden_size))
  450. self.patch_size = config.patch_size
  451. self.config = config
  452. @staticmethod
  453. def num_patches(config):
  454. if config.frames_per_clip > 1:
  455. return (
  456. (config.frames_per_clip // config.tubelet_size)
  457. * (config.crop_size // config.patch_size)
  458. * (config.crop_size // config.patch_size)
  459. )
  460. else:
  461. return (config.crop_size // config.patch_size) * (config.crop_size // config.patch_size)
  462. def forward(
  463. self,
  464. hidden_states: torch.Tensor,
  465. context_mask: list[torch.Tensor],
  466. target_mask: list[torch.Tensor],
  467. mask_index: int = 1,
  468. ) -> tuple[torch.Tensor, torch.Tensor]:
  469. """
  470. hidden_states : encoder outputs (context)
  471. context_mask: tokens of the context (outputs from the encoder)
  472. target_mask: tokens to predict
  473. mask_index: index of the target mask to choose (useful for multiclip?)
  474. """
  475. B = hidden_states.size(0)
  476. context = self.predictor_embeddings(hidden_states)
  477. # Make target tokens
  478. mask_index = mask_index % self.num_mask_tokens
  479. target = self.mask_tokens[mask_index]
  480. # Note: this is problematic if the config isn't initialized with the right frames_per_clip value,
  481. # e.g. for scenarios if we want to run predictor for more tokens than in the config.
  482. # target = target.repeat(B, self.num_patches(self.config), 1)
  483. # Remedy: use the provided target mask to get the max patch num
  484. max_patch_num = target_mask[0].max() + 1 # one extra to include the last patch
  485. target = target.repeat(B, max_patch_num, 1)
  486. target = apply_masks(target, target_mask)
  487. # Concatenate context & target tokens
  488. context = context.repeat(len(context_mask), 1, 1)
  489. embeddings = torch.cat([context, target], dim=1)
  490. # Positions of context & target tokens
  491. cm = torch.cat(context_mask, dim=0)
  492. tm = torch.cat(target_mask, dim=0)
  493. masks = torch.cat([cm, tm], dim=1)
  494. return embeddings, masks
class VJEPA2Predictor(nn.Module):
    """
    Predict target-token representations from masked context encodings.

    The encoder outputs are gathered with `context_mask` and combined with learnable mask tokens for the
    `target_mask` positions (`VJEPA2PredictorEmbeddings`). The sequence is sorted by patch index, run through
    `pred_num_hidden_layers` VJEPA2Layer blocks and a LayerNorm, and unsorted. Only the target tokens are kept
    and projected back to `hidden_size`.
    """

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.gradient_checkpointing = False
        self.embeddings = VJEPA2PredictorEmbeddings(config)
        # Per-layer stochastic-depth rates: linear from 0.0 to config.drop_path_rate (0.0 for a single layer).
        drop_path_rates = [
            (
                config.drop_path_rate * i / (config.pred_num_hidden_layers - 1)
                if config.pred_num_hidden_layers > 1
                else 0.0
            )
            for i in range(config.pred_num_hidden_layers)
        ]
        self.layer = nn.ModuleList(
            [
                VJEPA2Layer(
                    config,
                    drop_path_rate=drop_path_rates[i],
                    hidden_size=config.pred_hidden_size,
                    num_attention_heads=config.pred_num_attention_heads,
                    mlp_ratio=config.pred_mlp_ratio,
                )
                for i in range(config.pred_num_hidden_layers)
            ]
        )
        self.layernorm = nn.LayerNorm(config.pred_hidden_size, eps=config.layer_norm_eps)
        # Maps predictor outputs from pred_hidden_size back to the encoder's hidden_size.
        self.proj = nn.Linear(config.pred_hidden_size, config.hidden_size, bias=True)

    def sort_tokens(self, hidden_states, position_masks, argsort, head_mask=None):
        """
        Reorder tokens, their position indices and (optionally) the head mask by `argsort` along the token axis.

        Args:
            hidden_states: Gathered along dim 1, with `argsort` expanded over the last dim; presumably
                (batch, tokens, dim) — confirm.
            position_masks: Gathered along dim 1 with `argsort`.
            argsort: Per-sample token permutation indexed as (batch, tokens).
            head_mask: Optional 5-D tensor. After `permute(1, 0, 2, 3, 4)` both of its last two axes are reordered
                by `argsort`; the layout is presumably (layers, batch, heads, tokens, tokens) — TODO confirm.

        Returns:
            `(hidden_states, position_masks, head_mask)` in sorted order; `head_mask` is returned untouched when it
            is None or its first entry is None.
        """
        # gather position masks
        argsort = argsort.to(position_masks.device)
        position_masks = torch.gather(position_masks, dim=1, index=argsort)

        # gather hidden states
        argsort = argsort.to(hidden_states.device)
        hidden_states_argsort = argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        hidden_states = torch.gather(hidden_states, dim=1, index=hidden_states_argsort)

        # gather head mask
        if head_mask is not None and head_mask[0] is not None:
            argsort = argsort.to(head_mask.device)
            # Move the batch axis first so `argsort` (batch, tokens) can be broadcast onto it.
            head_mask = head_mask.permute(1, 0, 2, 3, 4)
            # Reorder axis 3 (rows) ...
            argsort_4d = (
                argsort.unsqueeze(1)
                .unsqueeze(1)
                .expand(-1, head_mask.size(1), head_mask.size(2), -1)
                .unsqueeze(-1)
                .expand(-1, -1, -1, -1, head_mask.size(-1))
            )
            head_mask = torch.gather(head_mask, dim=3, index=argsort_4d)
            # ... then axis 4 (columns) with the same permutation.
            argsort_5d = (
                argsort.unsqueeze(1)
                .unsqueeze(1)
                .unsqueeze(1)
                .expand(-1, head_mask.size(1), head_mask.size(2), head_mask.size(3), -1)
            )
            head_mask = torch.gather(head_mask, dim=4, index=argsort_5d)
            # Restore the original axis order.
            head_mask = head_mask.permute(1, 0, 2, 3, 4)

        return hidden_states, position_masks, head_mask

    def unsort_tokens(self, hidden_states, argsort):
        """Undo `sort_tokens` on `hidden_states` by gathering with the inverse permutation of `argsort`."""
        argsort = argsort.to(hidden_states.device)
        # argsort of a permutation is its inverse.
        reverse_argsort = torch.argsort(argsort, dim=1)
        reverse_argsort = reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        hidden_states = torch.gather(hidden_states, dim=1, index=reverse_argsort)
        return hidden_states

    @can_return_tuple
    def forward(
        self,
        encoder_hidden_states: torch.Tensor,
        context_mask: list[torch.Tensor],
        target_mask: list[torch.Tensor],
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        **kwargs,
    ) -> BaseModelOutput:
        """
        Predict the target tokens.

        Args:
            encoder_hidden_states: Full encoder output; the context tokens are selected here with `context_mask`.
            context_mask: Patch indices of the context tokens.
            target_mask: Patch indices of the tokens to predict.
            head_mask: Optional head mask, reordered together with the tokens (see `sort_tokens`).
            output_attentions / output_hidden_states: Collect per-layer attentions / hidden states.

        Returns:
            `BaseModelOutput` whose `last_hidden_state` holds the projected predictions for the target tokens only.
            When requested, the last entry of `hidden_states` is taken before `self.layernorm`, and all collected
            states are in sorted token order.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # mask out the encoder hidden states
        # this is implemented here as in VJEPA training a separate encoder is used for target
        encoder_hidden_states = apply_masks(encoder_hidden_states, context_mask)
        # N_ctxt: number of context tokens per sample; used below to slice them off after unsorting.
        _, N_ctxt, D = encoder_hidden_states.shape
        hidden_states, position_masks = self.embeddings(encoder_hidden_states, context_mask, target_mask)

        # Put tokens in sorted order
        argsort = torch.argsort(position_masks, dim=1)  # [B, N]
        hidden_states, position_masks, head_mask = self.sort_tokens(hidden_states, position_masks, argsort, head_mask)

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(hidden_states, position_masks, layer_head_mask, output_attentions)
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.layernorm(hidden_states)
        # unsort and extract the predicted tokens
        # After unsorting, tokens are back in embedding order [context..., target...], so drop the context part.
        hidden_states = self.unsort_tokens(hidden_states, argsort)
        hidden_states = hidden_states[:, N_ctxt:]
        # projection
        hidden_states = self.proj(hidden_states)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
  class VJEPA2PoolerSelfAttention(nn.Module):
      """Multi-head self-attention used inside the attentive pooler.

      Queries, keys and values are all projected from the same input sequence. The per-head
      results are merged back to `hidden_size` and passed through an output projection (`out_proj`).
      """

      def __init__(self, config: VJEPA2Config):
          super().__init__()
          self.config = config
          self.embed_dim = config.hidden_size
          self.num_heads = config.num_attention_heads
          self.head_dim = self.embed_dim // self.num_heads
          # hidden_size must split evenly across the heads.
          if self.num_heads * self.head_dim != self.embed_dim:
              raise ValueError(
                  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                  f" {self.num_heads})."
              )
          self.scale = self.head_dim**-0.5
          self.dropout = config.attention_dropout
          self.is_causal = False

          # Projections are created in k, v, q, out order. Keeping this order preserves parameter
          # registration (state-dict) order and the RNG draw order of the default nn.Linear init.
          self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
          self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
          self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
          self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = False,
      ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          """Self-attend over `hidden_states`.

          Args:
              hidden_states: tensor of shape `(batch, seq_len, embed_dim)`.
              attention_mask: optional mask, forwarded unchanged to the selected attention backend.
              output_attentions: if falsy, `None` is returned in place of the attention weights.

          Returns:
              `(attn_output, attn_weights)`. `attn_output` has the same shape as `hidden_states`.
          """
          batch_size, seq_length, embed_dim = hidden_states.shape
          # Shape used to split (batch, seq, embed) into (batch, seq, heads, head_dim) before moving heads forward.
          per_head_shape = (batch_size, seq_length, self.num_heads, self.head_dim)

          query_heads = self.q_proj(hidden_states).view(per_head_shape).transpose(1, 2)
          key_heads = self.k_proj(hidden_states).view(per_head_shape).transpose(1, 2)
          value_heads = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

          if self.config._attn_implementation == "eager":
              attention_interface: Callable = eager_attention_forward
          else:
              attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

          # Dropout is only active in training mode.
          dropout_p = self.dropout if self.training else 0.0
          context, attn_weights = attention_interface(
              self,
              query_heads,
              key_heads,
              value_heads,
              attention_mask,
              is_causal=self.is_causal,
              scaling=self.scale,
              dropout=dropout_p,
          )

          merged = context.reshape(batch_size, seq_length, embed_dim).contiguous()
          projected = self.out_proj(merged)
          return projected, (attn_weights if output_attentions else None)
  class VJEPA2PoolerCrossAttention(nn.Module):
      """Multi-head cross-attention used by the attentive pooler.

      Unlike the other cross-attention layers, it has no output projection layer (`o_proj`).
      The merged per-head outputs are returned directly.
      """

      # In case of a modular refactoring, the missing o_proj can be replaced with nn.Identity().

      def __init__(self, config: VJEPA2Config):
          super().__init__()
          self.config = config
          self.embed_dim = config.hidden_size
          self.num_heads = config.num_attention_heads
          self.head_dim = self.embed_dim // self.num_heads
          # hidden_size must split evenly across the heads.
          if self.num_heads * self.head_dim != self.embed_dim:
              raise ValueError(
                  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                  f" {self.num_heads})."
              )
          self.scale = self.head_dim**-0.5
          self.dropout = config.attention_dropout
          self.is_causal = False

          # Creation order (k, v, q) is kept to preserve state-dict order and default-init RNG order.
          self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
          self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
          self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)

      def forward(
          self,
          queries: torch.Tensor,
          keys: torch.Tensor,
          values: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = False,
      ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          """Attend from `queries` to `keys`/`values`.

          Args:
              queries: tensor of shape `(batch, q_len, embed_dim)`.
              keys: tensor of shape `(batch, kv_len, embed_dim)`.
              values: tensor whose batch/length must match `keys` (it is viewed with `kv_len`).
              attention_mask: optional mask, forwarded unchanged to the selected attention backend.
              output_attentions: if falsy, `None` is returned in place of the attention weights.

          Returns:
              `(attn_output, attn_weights)`. `attn_output` has shape `(batch, q_len, embed_dim)`.
          """
          batch_size, q_seq_length, embed_dim = queries.shape
          kv_seq_length = keys.shape[1]

          def _split_heads(projected: torch.Tensor, length: int) -> torch.Tensor:
              # (batch, length, embed) -> (batch, heads, length, head_dim)
              return projected.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

          query_heads = _split_heads(self.q_proj(queries), q_seq_length)
          key_heads = _split_heads(self.k_proj(keys), kv_seq_length)
          value_heads = _split_heads(self.v_proj(values), kv_seq_length)

          if self.config._attn_implementation == "eager":
              attention_interface: Callable = eager_attention_forward
          else:
              attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

          # Dropout is only active in training mode.
          dropout_p = self.dropout if self.training else 0.0
          context, attn_weights = attention_interface(
              self,
              query_heads,
              key_heads,
              value_heads,
              attention_mask,
              is_causal=self.is_causal,
              scaling=self.scale,
              dropout=dropout_p,
          )

          # No output projection: return the merged heads as-is.
          merged = context.reshape(batch_size, q_seq_length, embed_dim).contiguous()
          return merged, (attn_weights if output_attentions else None)
  # Modified from SiglipEncoderLayer, but we have to propagate proper hidden_size to VJEPA2MLP
  class VJEPA2PoolerSelfAttentionLayer(GradientCheckpointingLayer):
      """Pre-norm block used by the attentive pooler: LayerNorm -> self-attention -> residual,
      then LayerNorm -> MLP -> residual."""

      def __init__(self, config: VJEPA2Config):
          super().__init__()
          # Submodule creation order is kept (state-dict order / default-init RNG order).
          self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.self_attn = VJEPA2PoolerSelfAttention(config)
          self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor,
          output_attentions: Optional[bool] = False,
      ) -> tuple[torch.Tensor, ...]:
          """
          Args:
              hidden_states (`torch.FloatTensor`):
                  Input to the layer of shape `(batch, seq_len, embed_dim)`.
              attention_mask (`torch.FloatTensor`):
                  Optional mask, passed through to `VJEPA2PoolerSelfAttention` and from there to the
                  configured attention backend. Its expected semantics depend on that backend.
              output_attentions (`bool`, *optional*, defaults to `False`):
                  Whether to append this layer's attention weights to the returned tuple.

          Returns:
              `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
          """
          normed = self.layer_norm1(hidden_states)
          attn_out, attn_weights = self.self_attn(
              hidden_states=normed,
              attention_mask=attention_mask,
              output_attentions=output_attentions,
          )
          hidden_states = hidden_states + attn_out

          hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

          if output_attentions:
              return (hidden_states, attn_weights)
          return (hidden_states,)
  class VJEPA2PoolerCrossAttentionLayer(GradientCheckpointingLayer):
      """Cross-attention block of the attentive pooler.

      Only the key/value source (`hidden_state`) is layer-normalized before cross-attention; the
      queries enter unnormalized and also serve as the residual. A pre-norm MLP with its own
      residual follows.
      """

      def __init__(self, config: VJEPA2Config):
          super().__init__()
          # Submodule creation order is kept (state-dict order / default-init RNG order).
          self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.cross_attn = VJEPA2PoolerCrossAttention(config)
          self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size)

      def forward(
          self,
          queries: torch.Tensor,
          hidden_state: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          output_attentions: bool = False,
      ) -> tuple[torch.Tensor, ...]:
          """Let `queries` attend to `hidden_state`.

          Returns `(output,)`, or `(output, *extra)` when `output_attentions` is truthy, where `extra`
          is everything `cross_attn` returned after its first element.
          """
          normed_context = self.layer_norm1(hidden_state)
          attended, *extra = self.cross_attn(
              queries,
              normed_context,
              normed_context,
              attention_mask=attention_mask,
              output_attentions=output_attentions,
          )
          # The queries act as the residual stream of this block.
          pooled = queries + attended

          pooled = pooled + self.mlp(self.layer_norm2(pooled))

          if output_attentions:
              return (pooled, *extra)
          return (pooled,)
  class VJEPA2AttentivePooler(nn.Module):
      """Attentive pooler.

      Applies `config.num_pooler_layers` self-attention layers to the input tokens. A single learned
      query token then cross-attends to the result, producing one vector per batch element.
      """

      def __init__(self, config: VJEPA2Config):
          super().__init__()
          # One learnable query of shape (1, 1, hidden_size); it is tiled over the batch in forward().
          self.query_tokens = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
          self.cross_attention_layer = VJEPA2PoolerCrossAttentionLayer(config)
          self.self_attention_layers = nn.ModuleList(
              [VJEPA2PoolerSelfAttentionLayer(config) for _ in range(config.num_pooler_layers)]
          )

      def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
          """Pool `hidden_state` of shape `(batch, seq_len, hidden_size)` to `(batch, hidden_size)`."""
          for self_attention_layer in self.self_attention_layers:
              hidden_state = self_attention_layer(hidden_state, attention_mask=None)[0]

          batch_size = hidden_state.shape[0]
          batch_queries = self.query_tokens.repeat(batch_size, 1, 1)
          pooled = self.cross_attention_layer(batch_queries, hidden_state)[0]
          # Drop the singleton query axis: (batch, 1, hidden) -> (batch, hidden).
          return pooled.squeeze(1)
  @auto_docstring
  class VJEPA2PreTrainedModel(PreTrainedModel):
      config: VJEPA2Config
      base_model_prefix = "vjepa2"
      main_input_name = "pixel_values_videos"
      supports_gradient_checkpointing = True
      _no_split_modules = [
          "VJEPA2Layer",
          "VJEPA2PoolerSelfAttentionLayer",
          "VJEPA2PoolerCrossAttentionLayer",
          "VJEPA2PredictorEmbeddings",
      ]
      _supports_sdpa = True
      _supports_flash_attn = True

      def _init_weights(self, module):
          """Initialize the weights of a single submodule.

          All random draws use a truncated normal with `config.initializer_range` as the base std:
          - `VJEPA2AttentivePooler`: query tokens use the base std. The k-th self-attention layer's
            `self_attn.out_proj` and `mlp.fc2` weights use `base_std / sqrt(k)`. The cross-attention
            layer's `mlp.fc2` weight uses `base_std / sqrt(num_self_attention_layers + 1)`.
          - `VJEPA2PredictorEmbeddings`: mask tokens are zeroed, or drawn from the base std when
            `zero_init_mask_tokens` is false.
          - `nn.Linear` / `nn.Conv2d` / `nn.Conv3d`: weights use the base std; biases, when present, are zeroed.
          - `nn.LayerNorm`: weight set to 1, bias set to 0.
          """
          base_std = self.config.initializer_range

          def trunc_normal_f32_(weight, std):
              # `trunc_normal_` is not implemented for half precision on CPU, so sample in float32
              # and cast the result back to the parameter's dtype.
              as_f32 = weight.data.to(torch.float32)
              sampled = nn.init.trunc_normal_(as_f32, mean=0.0, std=std)
              weight.data = sampled.to(weight.dtype)

          if isinstance(module, VJEPA2AttentivePooler):
              trunc_normal_f32_(module.query_tokens, std=base_std)
              for depth, layer in enumerate(module.self_attention_layers, start=1):
                  depth_std = base_std / (depth**0.5)
                  trunc_normal_f32_(layer.self_attn.out_proj.weight, std=depth_std)
                  trunc_normal_f32_(layer.mlp.fc2.weight, std=depth_std)
              cross_std = base_std / (len(module.self_attention_layers) + 1) ** 0.5
              trunc_normal_f32_(module.cross_attention_layer.mlp.fc2.weight, std=cross_std)
              return

          if isinstance(module, VJEPA2PredictorEmbeddings):
              if module.zero_init_mask_tokens:
                  module.mask_tokens.data.zero_()
              else:
                  trunc_normal_f32_(module.mask_tokens, std=base_std)
              return

          if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
              trunc_normal_f32_(module.weight, std=base_std)
              if module.bias is not None:
                  module.bias.data.zero_()
              return

          if isinstance(module, nn.LayerNorm):
              module.bias.data.zero_()
              module.weight.data.fill_(1.0)
  838. def _convert_head_mask_to_5d(head_mask, num_hidden_layers):
  839. """
  840. Inputs:
  841. - head_mask: bsz x seq_length x seq_length | None
  842. Returns
  843. - [num_hidden_layers x batch x num_heads x seq_length x seq_length] | [num_hidden_layers]
  844. """
  845. if head_mask is not None:
  846. head_mask = head_mask.unsqueeze(1).unsqueeze(0)
  847. head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
  848. else:
  849. head_mask = [None] * num_hidden_layers
  850. return head_mask
  @auto_docstring
  class VJEPA2Model(VJEPA2PreTrainedModel):
      def __init__(self, config: VJEPA2Config):
          super().__init__(config)
          self.config = config

          self.encoder = VJEPA2Encoder(config)
          self.predictor = VJEPA2Predictor(config)

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self) -> VJEPA2PatchEmbeddings3D:
          return self.encoder.embeddings.patch_embeddings

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          pixel_values_videos: torch.Tensor,
          context_head_mask: Optional[torch.Tensor] = None,
          context_mask: Optional[list[torch.Tensor]] = None,
          target_head_mask: Optional[torch.Tensor] = None,
          target_mask: Optional[list[torch.Tensor]] = None,
          skip_predictor: bool = False,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          **kwargs,
      ) -> VJEPA2WithMaskedInputModelOutput:
          r"""
          context_head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
              The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard) for the context.
          context_mask (`list[torch.Tensor]`, *optional*):
              List of tensors of patch position ids, each of shape `(batch_size, num_selected_patches)`, indicating
              which encoder output patches are exposed to the predictor. If not provided, it defaults to
              `[torch.arange(N).unsqueeze(0).repeat(B, 1)]`, i.e. the full context is available to the predictor.
              This default applies independently of whether `target_mask` is given.
          target_head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
              The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard) for the target.
          target_mask (`list[torch.Tensor]`, *optional*):
              List of tensors of patch position ids, each of shape `(batch_size, num_selected_patches)`, indicating
              which encoder output patches are used as prediction targets for the predictor. If not provided, it
              defaults to `[torch.arange(N).unsqueeze(0).repeat(B, 1)]`, i.e. the predictor predicts all encoder
              patches. This default applies independently of whether `context_mask` is given.
          skip_predictor (bool):
              flag to skip the predictor forward, useful if you just need the encoder outputs
          """
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )

          if pixel_values_videos is None:
              raise ValueError("You have to specify pixel_values_videos")

          # Prepare head mask if needed
          context_head_mask = _convert_head_mask_to_5d(context_head_mask, self.config.num_hidden_layers)
          target_head_mask = _convert_head_mask_to_5d(target_head_mask, self.config.pred_num_hidden_layers)

          encoder_outputs: BaseModelOutput = self.encoder(
              pixel_values_videos=pixel_values_videos,
              head_mask=context_head_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
          )
          sequence_output = encoder_outputs.last_hidden_state

          # Each mask defaults to "all patches" on its own. Previously the defaults were only filled when
          # BOTH masks were missing, so passing just one left the other as None and crashed in `apply_masks`.
          if context_mask is None or target_mask is None:
              batch_size = pixel_values_videos.size(0)
              num_patches = sequence_output.size(1)  # ensure we are using dynamic patch size

              def _all_patches_mask() -> list[torch.Tensor]:
                  # A fresh (batch_size, num_patches) tensor of position ids 0..N-1 per call, so the
                  # context and target masks never share storage.
                  positions = torch.arange(num_patches, device=pixel_values_videos.device)
                  return [positions.unsqueeze(0).repeat((batch_size, 1))]

              if context_mask is None:
                  context_mask = _all_patches_mask()
              if target_mask is None:
                  target_mask = _all_patches_mask()

          if not skip_predictor:
              predictor_outputs: BaseModelOutput = self.predictor(
                  encoder_hidden_states=sequence_output,
                  context_mask=context_mask,
                  target_mask=target_mask,
                  head_mask=target_head_mask,
                  output_attentions=output_attentions,
                  output_hidden_states=output_hidden_states,
              )
              predictor_output = VJEPA2WithMaskedInputPredictorOutput(
                  last_hidden_state=predictor_outputs.last_hidden_state,
                  target_hidden_state=apply_masks(sequence_output, target_mask),
                  hidden_states=predictor_outputs.hidden_states,
                  attentions=predictor_outputs.attentions,
              )
          else:
              predictor_output = None

          encoder_output = VJEPA2WithMaskedInputModelOutput(
              last_hidden_state=sequence_output,
              masked_hidden_state=apply_masks(sequence_output, context_mask),
              hidden_states=encoder_outputs.hidden_states,
              attentions=encoder_outputs.attentions,
              predictor_output=predictor_output,
          )

          return encoder_output

      def get_vision_features(self, pixel_values_videos) -> torch.Tensor:
          """Run only the encoder (predictor skipped) and return its last hidden state."""
          encoder_output = self.forward(pixel_values_videos, skip_predictor=True)
          return encoder_output.last_hidden_state
  @auto_docstring(
      custom_intro="""
      V-JEPA 2 Model transformer with a video classification head on top (a linear layer on top of the attentive pooler).
      """
  )
  class VJEPA2ForVideoClassification(VJEPA2PreTrainedModel):
      def __init__(self, config: VJEPA2Config):
          super().__init__(config)

          self.num_labels = config.num_labels
          # Backbone (encoder + predictor); only the encoder is run for classification.
          self.vjepa2 = VJEPA2Model(config)

          # Classification head: attentive pooling over the encoder tokens followed by a linear layer.
          self.pooler = VJEPA2AttentivePooler(config)
          self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=True)

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          pixel_values_videos: torch.Tensor,
          labels: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
      ) -> Union[tuple, ImageClassifierOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the video classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

          Examples:

          ```python
          >>> import torch
          >>> import numpy as np
          >>> from transformers import AutoVideoProcessor, VJEPA2ForVideoClassification

          >>> device = "cuda"

          >>> video_processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2")
          >>> model = VJEPA2ForVideoClassification.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2").to(device)

          >>> video = np.ones((64, 256, 256, 3))  # 64 frames, 256x256 RGB
          >>> inputs = video_processor(video, return_tensors="pt").to(device)

          >>> # For inference
          >>> with torch.no_grad():
          ...     outputs = model(**inputs)
          >>> logits = outputs.logits

          >>> predicted_label = logits.argmax(-1).item()
          >>> print(model.config.id2label[predicted_label])

          >>> # For training
          >>> labels = torch.ones(1, dtype=torch.long, device=device)
          >>> loss = model(**inputs, labels=labels).loss

          ```"""
          backbone_outputs = self.vjepa2(
              pixel_values_videos=pixel_values_videos,
              skip_predictor=True,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
          )

          pooled = self.pooler(backbone_outputs.last_hidden_state)
          logits = self.classifier(pooled)

          if labels is None:
              loss = None
          else:
              loss = self.loss_function(pooled_logits=logits, labels=labels, config=self.config)

          return ImageClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=backbone_outputs.hidden_states,
              attentions=backbone_outputs.attentions,
          )
  # Public classes exported by this module.
  __all__ = [
      "VJEPA2Model",
      "VJEPA2PreTrainedModel",
      "VJEPA2ForVideoClassification",
  ]