modeling_tvp.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923
  1. # coding=utf-8
  2. # Copyright 2023 The Intel AIA Team Authors, and HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch TVP Model"""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Optional
  19. import torch
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
  24. from ...modeling_utils import PreTrainedModel
  25. from ...pytorch_utils import prune_linear_layer
  26. from ...utils import auto_docstring, logging
  27. from ...utils.backbone_utils import load_backbone
  28. from .configuration_tvp import TvpConfig
  # Module-level logger named after this module.
  logger = logging.get_logger(name=__name__)
  @dataclass
  @auto_docstring
  class TvpVideoGroundingOutput(ModelOutput):
      r"""
      loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
          Temporal-Distance IoU loss for video grounding.
      logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
          Contains start_time/duration and end_time/duration. It is the time slot of the videos corresponding to the
          input texts.
      hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
          Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
          shape `(batch_size, sequence_length, hidden_size)`.
      attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
          Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
          sequence_length)`.
      """

      loss: Optional[torch.FloatTensor] = None
      logits: Optional[torch.FloatTensor] = None
      hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
      attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  class TvpLoss(nn.Module):
      """
      This class computes the losses for `TvpForVideoGrounding`. The predicted time slot (normalized start/end logits
      scaled by the video duration) is compared with the ground-truth time slot, and each requested loss returns one
      value per sample.

      Args:
          losses (`list[str]`):
              List of all the losses to be applied. Supported names are `"iou"`, `"distance"` and `"duration"`.

      Raises:
          `ValueError`: if `losses` contains an unsupported name.
      """

      def __init__(self, losses):
          super().__init__()
          self.loss_map = {
              "iou": self.loss_iou,
              "distance": self.loss_distance,
              "duration": self.loss_duration,
          }
          for loss in losses:
              if loss not in self.loss_map:
                  raise ValueError(f"Loss {loss} not supported")
          self.losses = losses

      def loss_iou(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
          """
          Measure the intersection over union.

          Returns `1 - IoU` between the candidate and ground-truth intervals. `duration` is unused here; it is
          accepted so that every loss shares the same call signature.
          """
          inter = torch.min(candidates_end_time, end_time) - torch.max(candidates_start_time, start_time)
          union = torch.max(candidates_end_time, end_time) - torch.min(candidates_start_time, start_time)
          # NOTE(review): `union` is 0 when both intervals collapse onto the same point, which yields NaN.
          iou = 1 - inter.clamp(min=0) / union
          return iou

      def loss_distance(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
          """
          Measure the distance of mid points.

          Absolute distance between the candidate and ground-truth midpoints, divided by `duration` and clamped
          below at 0.2.
          """
          mid_candidates = torch.div(torch.add(candidates_start_time, candidates_end_time), 2.0)
          mid_groundtruth = torch.div(torch.add(start_time, end_time), 2.0)
          distance_diff = torch.div(
              torch.max(mid_candidates, mid_groundtruth) - torch.min(mid_candidates, mid_groundtruth), duration
          ).clamp(min=0.2)
          return distance_diff

      def loss_duration(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
          """
          Measure the difference of duration.

          Squared difference of candidate and ground-truth interval lengths, divided by `duration` before squaring
          and clamped below at 0.4.
          """
          duration_candidates = torch.sub(candidates_end_time, candidates_start_time)
          duration_groundtruth = torch.sub(end_time, start_time)
          duration_diff = torch.square(torch.div(torch.sub(duration_candidates, duration_groundtruth), duration))
          duration_diff = duration_diff.clamp(min=0.4)
          return duration_diff

      def forward(self, logits, labels):
          """
          This performs the loss computation.

          Args:
              logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
                  The output logits of head module: start time and end time, each divided by the video duration.
              labels (`list[torch.FloatTensor]`):
                  Three tensors in the order `(duration, start_time, end_time)`: the duration of each video and the
                  start/end time of the moment corresponding to the text.

          Returns:
              `dict[str, torch.FloatTensor]`: maps each requested loss name to its per-sample loss tensor.
          """
          duration, start_time, end_time = labels
          # A per-sample duration of shape `(batch_size,)` must scale both columns of its own row. Multiplied directly
          # with `(batch_size, 2)` logits it would broadcast along the column axis instead.
          if isinstance(duration, torch.Tensor) and duration.dim() == 1:
              scale = duration.unsqueeze(-1)
          else:
              scale = duration
          candidates = torch.mul(logits, scale)
          candidates_start_time, candidates_end_time = candidates[:, 0].float(), candidates[:, 1].float()
          return {
              loss: self.loss_map[loss](start_time, end_time, candidates_start_time, candidates_end_time, duration)
              for loss in self.losses
          }
  class TvpVisionModel(nn.Module):
      """
      Visual encoder: runs the configured backbone on every frame, projects the first backbone feature map to
      `config.hidden_size` channels with a 3x3 convolution, then applies 2x2 max pooling and ReLU.
      """

      def __init__(self, config):
          super().__init__()
          self.backbone = load_backbone(config)
          # Channel count of the backbone feature map, read from whichever config exposes it.
          if config.backbone_config is not None:
              in_channels = config.backbone_config.hidden_sizes[-1]
          elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_sizes"):
              in_channels = self.backbone.config.hidden_sizes[-1]
          elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_size"):
              in_channels = self.backbone.config.hidden_size
          else:
              raise ValueError("Backbone config not found")
          self.grid_encoder_conv = nn.Conv2d(
              in_channels,
              config.hidden_size,
              kernel_size=3,
              stride=1,
              padding=1,
              groups=1,
              bias=False,
          )

      def forward(self, pixel_values):
          """
          Args:
              pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
                  Video frames. Non-contiguous tensors are accepted.

          Returns:
              `torch.FloatTensor` of shape `(batch_size, num_frames, new_height, new_width, hidden_size)`.
          """
          batch_size, num_frames, num_channels, height, width = pixel_values.shape
          # (batch_size * num_frames, num_channels, height, width); `reshape` because `view` rejects
          # non-contiguous inputs (e.g. frames produced by a transpose/permute).
          pixel_values = pixel_values.reshape(batch_size * num_frames, num_channels, height, width)
          grid_feat_outputs = self.backbone(pixel_values)["feature_maps"][0]
          grid = self.grid_encoder_conv(grid_feat_outputs)
          grid = nn.functional.max_pool2d(grid, kernel_size=2, stride=2)
          grid = nn.functional.relu(grid, inplace=True)
          new_channel, new_height, new_width = grid.shape[-3:]
          # (batch_size, num_frames, num_channels, height, width)
          grid = grid.view(batch_size, num_frames, new_channel, new_height, new_width)
          # (batch_size, num_frames, height, width, num_channels)
          grid = grid.permute(0, 1, 3, 4, 2)
          return grid
  class TvpVisualInputEmbedding(nn.Module):
      """
      Takes input of both image and video (multi-frame)

      Frames are mean-pooled over time, 2D (row + column) position embeddings and a single token-type embedding are
      added, and the grid is flattened into a token sequence followed by LayerNorm and dropout.
      """

      def __init__(self, config):
          super().__init__()
          # sequence embedding
          # NOTE(review): `position_embeddings` is never read by this class; presumably kept so checkpoint parameter
          # names stay stable.
          self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
          self.row_position_embeddings = nn.Embedding(config.max_grid_row_position_embeddings, config.hidden_size)
          self.col_position_embeddings = nn.Embedding(config.max_grid_col_position_embeddings, config.hidden_size)
          self.token_type_embeddings = nn.Embedding(1, config.hidden_size)
          self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          self.max_grid_row_position_embeddings = config.max_grid_row_position_embeddings
          self.max_grid_col_position_embeddings = config.max_grid_col_position_embeddings

      def interpolate_pos_encoding(self, embedding: torch.Tensor, height: int, width: int) -> torch.Tensor:
          """
          This method allows to interpolate the pre-trained position embeddings, to be able to use the model on
          collection of high resolution images (high resolution videos).

          Args:
              embedding (`torch.Tensor` of shape `(batch_size, grid_height, grid_width, hidden_dim)`):
                  Position embeddings to resize.
              height (`int`), width (`int`):
                  Target grid size. A dimension is resized only when it exceeds the corresponding
                  `max_grid_*_position_embeddings`; otherwise it keeps the embedding's current size.

          Returns:
              `torch.Tensor` of shape `(batch_size, new_height, new_width, hidden_dim)`.
          """
          new_height = height if height > self.max_grid_row_position_embeddings else embedding.shape[1]
          new_width = width if width > self.max_grid_col_position_embeddings else embedding.shape[2]
          embedding = embedding.permute(0, 3, 1, 2)  # (batch_size, hidden_dim, height, width)
          # Interpolate to an explicit size: a fractional `scale_factor` can floor to one row/column short of the
          # grid, which would then fail to broadcast against it.
          embedding = nn.functional.interpolate(
              embedding,
              size=(new_height, new_width),
              mode="bicubic",
              align_corners=False,
          )
          embedding = embedding.permute(0, 2, 3, 1)  # (batch_size, height, width, hidden_dim)
          return embedding

      def add_2d_positional_embeddings(self, grid, interpolate_pos_encoding: bool = False):
          """
          Args:
              grid: (batch_size, height, width, hidden_dim)
              interpolate_pos_encoding: (`bool`, *optional*, defaults to `False`):
                  Whether to interpolate the pre-trained position encodings.
          Returns:
              `grid` plus row and column position embeddings: (batch_size, height, width, hidden_dim)
          """
          batch_size, height, width, hidden_dim = grid.shape
          # add row-wise position embeddings
          # (height, )
          row_height = min(self.max_grid_row_position_embeddings, height)
          row_position_ids = torch.arange(row_height, dtype=torch.long, device=grid.device)
          # (height, hidden_dim)
          row_position_embeddings = self.row_position_embeddings(row_position_ids)
          row_shape = (1,) * (len(grid.shape) - 3) + (row_height, 1, hidden_dim)
          # (1, height, 1, hidden_dim)
          row_position_embeddings = row_position_embeddings.view(*row_shape)
          # add column-wise position embeddings
          row_width = min(self.max_grid_col_position_embeddings, width)
          col_position_ids = torch.arange(row_width, dtype=torch.long, device=grid.device)
          # (width, hidden_dim)
          col_position_embeddings = self.col_position_embeddings(col_position_ids)
          # (1, 1, width, hidden_dim): a leading singleton broadcasts over the batch. A `(width, hidden_dim)` tensor
          # cannot be viewed with a leading `batch_size` unless batch_size == 1.
          col_shape = (1,) * (len(grid.shape) - 3) + (1, row_width, hidden_dim)
          col_position_embeddings = col_position_embeddings.view(*col_shape)
          # (1, height, width, hidden_dim)
          positional_embeddings = row_position_embeddings + col_position_embeddings
          # This interpolation gets triggered ONLY when the input image dim is larger in any dimension than the original position embeddings
          if interpolate_pos_encoding and (
              height > self.max_grid_row_position_embeddings or width > self.max_grid_col_position_embeddings
          ):
              grid = grid + self.interpolate_pos_encoding(positional_embeddings, height, width)
          else:
              grid = grid + positional_embeddings
          return grid

      def forward(self, grid, interpolate_pos_encoding: bool = False):
          """
          Args:
              grid: Array of shape (batch_size, num_frames, height, width, num_channels).
                  It contains processed frames extracted from videos, and is generated by Tvp image preprocessor. Note,
                  num_frames can be 1
              interpolate_pos_encoding: (bool, *optional*, defaults to `False`):
                  Whether to interpolate the pre-trained position encodings.
          Returns:
              embeddings: The embedding of grid with size (batch_size, height*width, num_channels)
          """
          batch_size, num_frames, height, width, num_channels = grid.shape
          # temporal mean pooling, (batch_size, height, width, hidden_size)
          grid = grid.mean(1)
          grid = self.add_2d_positional_embeddings(grid, interpolate_pos_encoding=interpolate_pos_encoding)
          # image token sequence, (batch_size, height*width, num_channels); `reshape` tolerates permuted layouts
          visual_tokens = grid.reshape(batch_size, -1, num_channels)
          visual_tokens_shape = visual_tokens.shape[:-1]
          device = visual_tokens.device
          # image token type embeddings.
          token_type_ids = torch.zeros(visual_tokens_shape, dtype=torch.long, device=device)
          token_type_embeddings = self.token_type_embeddings(token_type_ids)
          embeddings = visual_tokens + token_type_embeddings
          embeddings = self.layer_norm(embeddings)
          embeddings = self.dropout(embeddings)
          return embeddings
  class TvpTextInputEmbeddings(nn.Module):
      """Build text token embeddings as word + position + token-type embeddings, followed by LayerNorm and dropout."""

      def __init__(self, config):
          super().__init__()
          self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
          self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
          self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
          self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)

      def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
          """
          Args:
              input_ids: token ids; when given they define the `(batch_size, seq_length)` shape.
              token_type_ids: segment ids; default to all zeros.
              position_ids: positions; default to `0 .. seq_length - 1` for every row.
              inputs_embeds: precomputed word embeddings, used instead of looking up `input_ids`.

          Returns:
              Embeddings of shape `(*input_shape, hidden_size)`.
          """
          if input_ids is None:
              input_shape = inputs_embeds.size()[:-1]
          else:
              input_shape = input_ids.size()
          seq_length = input_shape[1]
          reference = inputs_embeds if input_ids is None else input_ids
          device = reference.device

          if token_type_ids is None:
              token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
          if position_ids is None:
              position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).expand(input_shape)

          word_vectors = self.word_embeddings(input_ids) if inputs_embeds is None else inputs_embeds
          position_vectors = self.position_embeddings(position_ids)
          type_vectors = self.token_type_embeddings(token_type_ids)
          return self.dropout(self.layer_norm(word_vectors + position_vectors + type_vectors))
  class TvpAttention(nn.Module):
      """
      Multi-head self-attention: scaled dot-product attention over `config.num_attention_heads` heads, an output
      projection with dropout, and a residual connection followed by LayerNorm.
      """

      def __init__(self, config):
          super().__init__()
          head_count = config.num_attention_heads
          divisible = config.hidden_size % head_count == 0
          if not divisible and not hasattr(config, "embedding_size"):
              raise ValueError(
                  f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads {config.num_attention_heads}"
              )
          self.num_attention_heads = head_count
          self.attention_head_size = int(config.hidden_size / head_count)
          self.all_head_size = self.num_attention_heads * self.attention_head_size
          self.query = nn.Linear(config.hidden_size, self.all_head_size)
          self.key = nn.Linear(config.hidden_size, self.all_head_size)
          self.value = nn.Linear(config.hidden_size, self.all_head_size)
          self.attn_dropout = nn.Dropout(config.attention_probs_dropout_prob)
          self.dense = nn.Linear(config.hidden_size, config.hidden_size)
          self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          self.pruned_heads = set()

      def prune_heads(self, heads):
          """Remove the given heads (indices in the original, unpruned numbering) from this attention layer."""
          if len(heads) == 0:
              return
          new_heads = set(heads) - self.pruned_heads  # ignore heads that were already removed
          keep = torch.ones(self.num_attention_heads, self.attention_head_size)
          for head in new_heads:
              # Shift left by the number of already-pruned heads that precede this one.
              keep[head - sum(1 for h in self.pruned_heads if h < head)] = 0
          keep = keep.view(-1).contiguous().eq(1)
          index = torch.arange(len(keep))[keep].long()

          self.query, self.key, self.value = [
              prune_linear_layer(projection, index) for projection in (self.query, self.key, self.value)
          ]
          self.dense = prune_linear_layer(self.dense, index, dim=1)

          self.num_attention_heads = self.num_attention_heads - len(new_heads)
          self.all_head_size = self.attention_head_size * self.num_attention_heads
          self.pruned_heads = self.pruned_heads.union(new_heads)

      def _reshape(self, tensor: torch.Tensor, sequence_length: int, batch_size: int):
          """Split heads: (batch, seq, all_head_size) -> (batch, heads, seq, head_size)."""
          per_head = tensor.view(batch_size, sequence_length, self.num_attention_heads, self.attention_head_size)
          return per_head.transpose(1, 2).contiguous()

      def forward(
          self,
          hidden_states,
          attention_mask=None,
          head_mask=None,
          output_attentions: Optional[bool] = None,
      ):
          """
          Returns a 1-tuple `(attn_output,)`, or `(attn_output, attention_probs)` when `output_attentions` is truthy.
          `attention_mask` is added to the raw scores; `head_mask` multiplies the attention probabilities.
          """
          batch_size, sequence_length = hidden_states.shape[:2]
          query_layer, key_layer, value_layer = [
              self._reshape(projection(hidden_states), sequence_length, batch_size)
              for projection in (self.query, self.key, self.value)
          ]

          scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
          if attention_mask is not None:
              scores = scores + attention_mask

          # Dropout on the probabilities drops whole tokens to attend to, as in the original Transformer.
          probs = self.attn_dropout(nn.functional.softmax(scores, dim=-1))
          if head_mask is not None:
              probs = probs * head_mask

          context = torch.matmul(probs, value_layer).transpose(1, 2).contiguous()
          context = context.reshape(batch_size, sequence_length, self.all_head_size)
          projected = self.dropout(self.dense(context))
          attn_output = self.layer_norm(projected + hidden_states)

          if output_attentions:
              return (attn_output, probs)
          return (attn_output,)
  # Adapted from transformers.models.bert.modeling_bert.BertIntermediate (Bert->Tvp)
  class TvpIntermediate(nn.Module):
      """Feed-forward expansion: project to `config.intermediate_size`, then apply the configured activation."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
          activation = config.hidden_act
          # A string names an activation in ACT2FN; anything else is used as the activation callable itself.
          self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          return self.intermediate_act_fn(self.dense(hidden_states))
  class TvpOutputLayer(nn.Module):
      """Project the feed-forward output back to `hidden_size`, then apply dropout and a residual LayerNorm."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
          self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)

      def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
          projected = self.dropout(self.dense(hidden_states))
          residual_sum = projected + input_tensor
          return self.layer_norm(residual_sum)
  class TvpEncodeLayer(GradientCheckpointingLayer):
      """One encoder layer: self-attention followed by the feed-forward (intermediate + output) block."""

      def __init__(self, config):
          super().__init__()
          self.attention = TvpAttention(config)
          self.intermediate = TvpIntermediate(config)
          self.output = TvpOutputLayer(config)

      def forward(
          self,
          hidden_states,
          attention_mask=None,
          head_mask=None,
          output_attentions: Optional[bool] = None,
      ):
          """Return `(layer_output,)`, followed by the attention probabilities when `output_attentions` is set."""
          attention_output, *attention_extras = self.attention(
              hidden_states,
              attention_mask,
              head_mask,
              output_attentions=output_attentions,
          )
          feed_forward = self.intermediate(attention_output)
          layer_output = self.output(feed_forward, attention_output)
          return (layer_output, *attention_extras)
  class TvpEncoder(nn.Module):
      """Stack of `config.num_hidden_layers` `TvpEncodeLayer` modules applied in sequence."""

      def __init__(self, config):
          super().__init__()
          self.config = config
          self.layer = nn.ModuleList([TvpEncodeLayer(config) for _ in range(config.num_hidden_layers)])
          self.gradient_checkpointing = False

      def forward(
          self,
          hidden_states,
          attention_mask=None,
          head_mask: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ):
          """
          Args:
              hidden_states (`torch.FloatTensor`):
                  Input sequence of the first layer.
              attention_mask (`torch.FloatTensor`, *optional*):
                  Passed unchanged to every layer.
              head_mask (indexable of per-layer masks, *optional*):
                  `head_mask[i]` is passed to layer `i`; `None` disables head masking for every layer.
              output_attentions / output_hidden_states / return_dict (`bool`, *optional*):
                  Default to the matching `config` attributes.

          Returns:
              `BaseModelOutput` when `return_dict` is truthy, otherwise a tuple
              `(last_hidden_state, [all_hidden_states], [all_attentions])`.
          """
          return_dict = return_dict if return_dict is not None else self.config.return_dict
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )

          all_hidden_states = () if output_hidden_states else None
          all_attentions = () if output_attentions else None
          for i, layer_module in enumerate(self.layer):
              if output_hidden_states:
                  all_hidden_states = all_hidden_states + (hidden_states,)
              # Tolerate the default `head_mask=None` instead of failing on `None[i]`.
              layer_head_mask = head_mask[i] if head_mask is not None else None
              layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
              hidden_states = layer_outputs[0]
              if output_attentions:
                  all_attentions = all_attentions + (layer_outputs[1],)
          # Add last layer
          if output_hidden_states:
              all_hidden_states = all_hidden_states + (hidden_states,)

          if not return_dict:
              outputs = (hidden_states,)
              if output_hidden_states:
                  outputs = outputs + (all_hidden_states,)
              if output_attentions:
                  outputs = outputs + (all_attentions,)
              return outputs  # last-layer hidden state, (all hidden states), (all attentions)
          return BaseModelOutput(
              last_hidden_state=hidden_states,
              hidden_states=all_hidden_states,
              attentions=all_attentions,
          )
  # Adapted from transformers.models.bert.modeling_bert.BertPooler (Bert->Tvp)
  class TvpPooler(nn.Module):
      """Pool a sequence by passing the hidden state of its first token through a dense layer and tanh."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.hidden_size)
          self.activation = nn.Tanh()

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          first_token = hidden_states[:, 0]
          return self.activation(self.dense(first_token))
  @auto_docstring
  class TvpPreTrainedModel(PreTrainedModel):
      config: TvpConfig
      base_model_prefix = "model"
      supports_gradient_checkpointing = True

      def _init_weights(self, module: nn.Module):
          """
          Initialize the weights of `module`.

          - Linear/Embedding weights: normal with std `config.initializer_range`.
          - LayerNorm: identity (weight 1, bias 0).
          - Conv2d: Kaiming-normal (fan_out, relu).
          - `TvpModel.text_prompt`: standard normal.
          - Linear biases: zero.
          - Visual-prompter pads (`pad_up`, `pad_down`, `pad_left`, `pad_right`): standard normal.
          """
          if isinstance(module, (nn.Linear, nn.Embedding)):
              # Slightly different from the TF version which uses truncated_normal for initialization
              # cf https://github.com/pytorch/pytorch/pull/5617
              module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
          elif isinstance(module, nn.LayerNorm):
              module.bias.data.zero_()
              module.weight.data.fill_(1.0)
          elif isinstance(module, nn.Conv2d):
              nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
              if module.bias is not None:
                  nn.init.constant_(module.bias, 0)
          elif isinstance(module, TvpModel):
              nn.init.normal_(module.text_prompt)

          if isinstance(module, nn.Linear) and module.bias is not None:
              module.bias.data.zero_()

          for pad_name in ("pad_up", "pad_down", "pad_left", "pad_right"):
              if hasattr(module, pad_name):
                  nn.init.normal_(getattr(module, pad_name))
  class TvpFrameDownPadPrompter(nn.Module):
      """
      Pad frames extracted from videos only at the bottom.

      A learnable strip of `visual_prompt_size` rows sits at the bottom of every frame:
      - with `visual_prompter_apply="replace"` or `"remove"`, the bottom rows of the input are zeroed first;
      - with `"add"` or `"replace"`, the learnable strip is then added.
      """

      def __init__(self, config):
          if config.visual_prompter_apply not in ("add", "replace", "remove"):
              raise ValueError("`visual_prompter_apply` must be in (add, replace, remove)")
          super().__init__()
          # Prefer `frame_num` when the config defines it; otherwise fall back to `num_frames`, which is the
          # attribute `TvpFramePadPrompter` reads.
          frame_num = getattr(config, "frame_num", None)
          if frame_num is None:
              frame_num = config.num_frames
          self.visual_prompt_size = config.visual_prompt_size
          self.frame_num = frame_num
          self.max_img_size = config.max_img_size
          self.visual_prompter_apply = config.visual_prompter_apply
          self.pad_down = nn.Parameter(
              torch.randn([1, frame_num, 3, config.visual_prompt_size, config.max_img_size])
          )

      def forward(self, pixel_values):
          """
          Args:
              pixel_values (`torch.FloatTensor`):
                  Frames expected to be `(batch_size, frame_num, 3, max_img_size, max_img_size)` so that they line up
                  with the prompt built below.

          Returns:
              A new tensor; `pixel_values` itself is not modified.
          """
          start_point = self.max_img_size - self.visual_prompt_size
          if self.visual_prompter_apply != "add":
              visual_prompt_mask = torch.ones(
                  [self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
              )
              visual_prompt_mask[start_point : self.max_img_size, :] = 0.0
              # Out-of-place: do not mutate the caller's tensor (in-place ops also fail on autograd leaves).
              pixel_values = pixel_values * visual_prompt_mask
          if self.visual_prompter_apply != "remove":
              prompt = torch.zeros(
                  [pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size],
                  device=pixel_values.device,
              )
              prompt[:, :, :, start_point : self.max_img_size, :] = self.pad_down
              pixel_values = pixel_values + prompt.to(pixel_values.dtype)
          return pixel_values
  class TvpFramePadPrompter(nn.Module):
      """
      Pad frames extracted from videos in the surroundings.

      Learnable borders of width `visual_prompt_size` (top, bottom, left, right) enclose a zero-filled centre of side
      `max_img_size - 2 * visual_prompt_size`. The assembled prompt is added to every frame.
      """

      def __init__(self, config):
          if config.visual_prompter_apply not in ("add", "replace", "remove"):
              raise ValueError("`visual_prompter_apply` must be in (add, replace, remove)")
          super().__init__()
          self.num_frames = config.num_frames
          self.max_img_size = config.max_img_size
          self.visual_prompter_apply = config.visual_prompter_apply
          # Side length of the zero-filled centre enclosed by the four borders.
          self.base_size = config.max_img_size - config.visual_prompt_size * 2
          self.pad_up = nn.Parameter(
              torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size])
          )
          self.pad_down = nn.Parameter(
              torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size])
          )
          self.pad_left = nn.Parameter(
              torch.randn(
                  [
                      1,
                      config.num_frames,
                      3,
                      config.max_img_size - config.visual_prompt_size * 2,
                      config.visual_prompt_size,
                  ]
              )
          )
          self.pad_right = nn.Parameter(
              torch.randn(
                  [
                      1,
                      config.num_frames,
                      3,
                      config.max_img_size - config.visual_prompt_size * 2,
                      config.visual_prompt_size,
                  ]
              )
          )

      def interpolate_pad_encoding(self, prompt: torch.Tensor, height: int, width: int) -> torch.Tensor:
          """
          This method allows to interpolate the pre-trained pad weights, to be able to use the model on collection of high
          resolution images (high resolution videos).

          Args:
              prompt (`torch.Tensor` of shape `(batch, num_frames, channels, prompt_height, prompt_width)`):
                  Assembled prompt to resize.
              height (`int`), width (`int`):
                  Target spatial size.

          Returns:
              `torch.Tensor` of shape `(batch, num_frames, channels, height, width)`.
          """
          batch, num_frames, channels, prompt_height, prompt_width = prompt.shape
          # reshaping the batch and num_frames dimension into a single one (i.e (b,frames,c,h,w)-->(b*frames,c,h,w)), to apply bicubic interpolation
          prompt = prompt.reshape(batch * num_frames, channels, prompt_height, prompt_width)
          # Interpolate straight to the target size. A scale factor of `height / max_img_size` can floor to one pixel
          # short of `height`, which would make the reshape below fail.
          prompt = nn.functional.interpolate(
              prompt,
              size=(height, width),
              mode="bicubic",
              align_corners=False,
          )
          # reversing back to (batch,frames,channels,height,width), where height and width is the new interpolated height and width
          prompt = prompt.reshape(batch, num_frames, channels, height, width)
          return prompt

      def forward(self, pixel_values, interpolate_pad_encoding: bool = False):
          """
          Args:
              pixel_values (`torch.FloatTensor`):
                  Frames; the frame axis (dim 1) must match `num_frames`.
              interpolate_pad_encoding (`bool`, *optional*, defaults to `False`):
                  Resize the prompt to the input's spatial size instead of `max_img_size`.

          Returns:
              A new tensor; `pixel_values` itself is not modified.
          """
          height, width = (
              (pixel_values.shape[-2], pixel_values.shape[-1])
              if interpolate_pad_encoding
              else (self.max_img_size, self.max_img_size)
          )
          if self.visual_prompter_apply not in ("add", "remove", "replace"):
              raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
          if self.visual_prompter_apply in ("replace", "remove"):
              # NOTE(review): this mask is all ones, so it leaves `pixel_values` unchanged. As a result, "remove" is a
              # no-op and "replace" behaves like "add". Kept as is because pretrained outputs may depend on it;
              # confirm against the reference TVP implementation.
              visual_prompt_mask = torch.ones([height, width], dtype=pixel_values.dtype, device=pixel_values.device)
              # Out-of-place: do not mutate the caller's tensor (in-place ops also fail on autograd leaves).
              pixel_values = pixel_values * visual_prompt_mask
          if self.visual_prompter_apply in ("replace", "add"):
              base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device)
              prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
              prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
              # Build and (optionally) interpolate a single prompt. Broadcasting over the batch adds the same values
              # a per-sample copy would.
              if interpolate_pad_encoding:
                  prompt = self.interpolate_pad_encoding(prompt, height, width)
              pixel_values = pixel_values + prompt.to(pixel_values.dtype)
          return pixel_values
  601. TVP_PROMPTER_CLASSES_MAPPING = {
  602. "framedownpad": TvpFrameDownPadPrompter,
  603. "framepad": TvpFramePadPrompter,
  604. }
  605. @auto_docstring(
  606. custom_intro="""
  607. The bare Tvp Model transformer outputting BaseModelOutputWithPooling object without any specific head on top.
  608. """
  609. )
  610. class TvpModel(TvpPreTrainedModel):
  611. def __init__(self, config):
  612. super().__init__(config)
  613. self.config = config
  614. self.vision_model = TvpVisionModel(config)
  615. self.embeddings = TvpTextInputEmbeddings(config)
  616. self.visual_embeddings = TvpVisualInputEmbedding(config)
  617. self.encoder = TvpEncoder(config)
  618. self.pooler = TvpPooler(config)
  619. self.text_prompt = nn.Parameter(torch.randn([1, 10, config.hidden_size]))
  620. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  621. if config.visual_prompter_type not in TVP_PROMPTER_CLASSES_MAPPING:
  622. raise ValueError("`visual_prompter_type` must be in (framedownpad, framepad)")
  623. self.visual_prompter = TVP_PROMPTER_CLASSES_MAPPING[config.visual_prompter_type](config)
  624. self.post_init()
  625. def get_input_embeddings(self):
  626. return self.embeddings.word_embeddings
  627. def set_input_embeddings(self, value):
  628. self.embeddings.word_embeddings = value
  629. def _prune_heads(self, heads_to_prune):
  630. """Prunes heads of the model.
  631. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
  632. """
  633. for layer, heads in heads_to_prune.items():
  634. self.encoder.layer[layer].attention.prune_heads(heads)
  635. @auto_docstring
  636. def forward(
  637. self,
  638. input_ids: Optional[torch.LongTensor] = None,
  639. pixel_values: Optional[torch.FloatTensor] = None,
  640. attention_mask: Optional[torch.LongTensor] = None,
  641. head_mask: Optional[torch.FloatTensor] = None,
  642. output_attentions: Optional[bool] = None,
  643. output_hidden_states: Optional[bool] = None,
  644. return_dict: Optional[bool] = None,
  645. interpolate_pos_encoding: bool = False,
  646. ):
  647. r"""
  648. Examples:
  649. ```python
  650. >>> import torch
  651. >>> from transformers import AutoConfig, AutoTokenizer, TvpModel
  652. >>> model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp")
  653. >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp")
  654. >>> pixel_values = torch.rand(1, 1, 3, 448, 448)
  655. >>> text_inputs = tokenizer("This is an example input", return_tensors="pt")
  656. >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
  657. ```"""
  658. return_dict = return_dict if return_dict is not None else self.config.return_dict
  659. # Add visual prompt, it compensates for the spatiotemporal information loss in 2D visual features.
  660. pixel_values = self.vision_model(
  661. self.visual_prompter(pixel_values, interpolate_pad_encoding=interpolate_pos_encoding)
  662. )
  663. # (batch_size, sequence_length, hidden_size)
  664. text_embedding_output = self.embeddings(input_ids=input_ids)
  665. # (batch_size, visual_sequence_length, hidden_size)
  666. visual_embedding_output = self.visual_embeddings(
  667. pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
  668. )
  669. if attention_mask is not None:
  670. # (batch_size, visual_sequence_length)
  671. visual_attention_mask = attention_mask.new_ones(visual_embedding_output.shape[:2])
  672. pt_mask = torch.ones(attention_mask.shape[0], 10).to(
  673. device=attention_mask.device, dtype=attention_mask.dtype
  674. )
  675. attention_mask = torch.cat([pt_mask, attention_mask, visual_attention_mask], dim=-1)
  676. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  677. # ourselves in which case we just need to make it broadcastable to all heads.
  678. attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size()).to(input_ids.device)
  679. text_prompt = self.text_prompt.expand(text_embedding_output.shape[0], -1, -1)
  680. # (batch_size, sequence_length + visual_sequence_length, hidden_size)
  681. embedding_output = torch.cat([text_prompt, text_embedding_output, visual_embedding_output], dim=1)
  682. encoder_outputs = self.encoder(
  683. embedding_output,
  684. attention_mask=attention_mask,
  685. head_mask=self.get_head_mask(head_mask, self.config.num_hidden_layers),
  686. output_attentions=output_attentions,
  687. output_hidden_states=output_hidden_states,
  688. return_dict=return_dict,
  689. )
  690. last_hidden_state = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
  691. pooled_output = self.pooler(last_hidden_state)
  692. last_hidden_state = self.dropout(last_hidden_state)
  693. pooled_output = self.dropout(pooled_output)
  694. if not return_dict:
  695. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  696. return BaseModelOutputWithPooling(
  697. last_hidden_state=last_hidden_state,
  698. pooler_output=pooled_output,
  699. hidden_states=encoder_outputs.hidden_states,
  700. attentions=encoder_outputs.attentions,
  701. )
  702. class TvpVideoGroundingHead(nn.Module):
  703. def __init__(self, config):
  704. super().__init__()
  705. self.layer_0 = nn.Linear(config.hidden_size, config.hidden_size * 2)
  706. self.layer_1 = nn.Linear(config.hidden_size * 2, 2)
  707. self.activation_0 = nn.ReLU()
  708. self.activation_1 = nn.Sigmoid()
  709. def forward(self, pooler_output):
  710. logits = self.activation_0(self.layer_0(pooler_output))
  711. logits = self.activation_1(self.layer_1(logits))
  712. return logits
  713. @auto_docstring(
  714. custom_intro="""
  715. Tvp Model with a video grounding head on top computing IoU, distance, and duration loss.
  716. """
  717. )
  718. class TvpForVideoGrounding(TvpPreTrainedModel):
  719. def __init__(self, config):
  720. super().__init__(config)
  721. self.config = config
  722. self.model = TvpModel(config)
  723. self.video_grounding_head = TvpVideoGroundingHead(config)
  724. self.post_init()
  725. @auto_docstring
  726. def forward(
  727. self,
  728. input_ids: Optional[torch.LongTensor] = None,
  729. pixel_values: Optional[torch.FloatTensor] = None,
  730. attention_mask: Optional[torch.LongTensor] = None,
  731. labels: Optional[tuple[torch.Tensor]] = None,
  732. head_mask: Optional[torch.FloatTensor] = None,
  733. output_attentions: Optional[bool] = None,
  734. output_hidden_states: Optional[bool] = None,
  735. return_dict: Optional[bool] = None,
  736. interpolate_pos_encoding: bool = False,
  737. ):
  738. r"""
  739. labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
  740. The labels contains duration, start time, and end time of the video corresponding to the text.
  741. Examples:
  742. ```python
  743. >>> import torch
  744. >>> from transformers import AutoConfig, AutoTokenizer, TvpForVideoGrounding
  745. >>> model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp")
  746. >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp")
  747. >>> pixel_values = torch.rand(1, 1, 3, 448, 448)
  748. >>> text_inputs = tokenizer("This is an example input", return_tensors="pt")
  749. >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
  750. ```"""
  751. return_dict = return_dict if return_dict is not None else self.config.return_dict
  752. outputs = self.model(
  753. input_ids,
  754. pixel_values,
  755. attention_mask,
  756. head_mask=head_mask,
  757. output_attentions=output_attentions,
  758. output_hidden_states=output_hidden_states,
  759. return_dict=return_dict,
  760. interpolate_pos_encoding=interpolate_pos_encoding,
  761. )
  762. pooler_output = outputs[1]
  763. logits = self.video_grounding_head(pooler_output)
  764. loss = None
  765. if labels is not None:
  766. criterion = TvpLoss(["iou", "distance", "duration"])
  767. criterion.to(self.device)
  768. loss_dict = criterion(logits, labels)
  769. loss = (
  770. loss_dict["iou"]
  771. + self.config.distance_loss_weight * loss_dict["distance"]
  772. + self.config.duration_loss_weight * loss_dict["duration"]
  773. )
  774. if not return_dict:
  775. outputs = (logits,) + outputs[2:]
  776. if loss is not None:
  777. outputs = (loss,) + outputs
  778. return outputs
  779. return TvpVideoGroundingOutput(
  780. loss=loss,
  781. logits=logits,
  782. hidden_states=outputs.hidden_states,
  783. attentions=outputs.attentions,
  784. )
# Public API of this module: the names exported by `from <module> import *`.
__all__ = ["TvpModel", "TvpPreTrainedModel", "TvpForVideoGrounding"]