modeling_vilt.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351
  1. # coding=utf-8
  2. # Copyright 2022 NAVER AI Labs and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch ViLT model."""
  16. import collections.abc
  17. import math
  18. from dataclasses import dataclass
  19. from typing import Optional, Union
  20. import torch
  21. from torch import nn
  22. from torch.nn import CrossEntropyLoss
  23. from ...activations import ACT2FN
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import (
  26. BaseModelOutput,
  27. BaseModelOutputWithPooling,
  28. MaskedLMOutput,
  29. ModelOutput,
  30. SequenceClassifierOutput,
  31. TokenClassifierOutput,
  32. )
  33. from ...modeling_utils import PreTrainedModel
  34. from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
  35. from ...utils import auto_docstring, logging
  36. from .configuration_vilt import ViltConfig
# Module-level logger from the library's `logging` utilities, named after this module.
logger = logging.get_logger(__name__)
  38. @dataclass
  39. @auto_docstring(
  40. custom_intro="""
  41. Class for outputs of [`ViltForImagesAndTextClassification`].
  42. """
  43. )
  44. class ViltForImagesAndTextClassificationOutput(ModelOutput):
  45. r"""
  46. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  47. Classification (or regression if config.num_labels==1) loss.
  48. logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
  49. Classification (or regression if config.num_labels==1) scores (before SoftMax).
  50. hidden_states (`list[tuple(torch.FloatTensor)]`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  51. List of tuples of `torch.FloatTensor` (one for each image-text pair, each tuple containing the output of
  52. the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
  53. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
  54. """
  55. loss: Optional[torch.FloatTensor] = None
  56. logits: Optional[torch.FloatTensor] = None
  57. hidden_states: Optional[list[tuple[torch.FloatTensor]]] = None
  58. attentions: Optional[list[tuple[torch.FloatTensor]]] = None
  59. class ViltEmbeddings(nn.Module):
  60. """
  61. Construct the text and patch embeddings.
  62. Text embeddings are equivalent to BERT embeddings.
  63. Patch embeddings are equivalent to ViT embeddings.
  64. """
  65. def __init__(self, config):
  66. super().__init__()
  67. # text embeddings
  68. self.text_embeddings = TextEmbeddings(config)
  69. # patch embeddings
  70. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  71. self.patch_embeddings = ViltPatchEmbeddings(config)
  72. num_patches = self.patch_embeddings.num_patches
  73. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  74. # modality type (text/patch) embeddings
  75. self.token_type_embeddings = nn.Embedding(config.modality_type_vocab_size, config.hidden_size)
  76. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  77. self.config = config
  78. def visual_embed(self, pixel_values, pixel_mask, max_image_length=200):
  79. _, _, ph, pw = self.patch_embeddings.projection.weight.shape
  80. x = self.patch_embeddings(pixel_values)
  81. x_mask = pixel_mask[:, None, :, :].float()
  82. x_mask = nn.functional.interpolate(x_mask, size=(x.shape[2], x.shape[3])).long()
  83. x_h = x_mask[:, 0].sum(dim=1)[:, 0]
  84. x_w = x_mask[:, 0].sum(dim=2)[:, 0]
  85. batch_size, num_channels, height, width = x.shape
  86. patch_dim = self.config.image_size // self.config.patch_size
  87. spatial_pos = self.position_embeddings[:, 1:, :].transpose(1, 2).view(1, num_channels, patch_dim, patch_dim)
  88. pos_embed = torch.cat(
  89. [
  90. nn.functional.pad(
  91. nn.functional.interpolate(
  92. spatial_pos,
  93. size=(h, w),
  94. mode="bilinear",
  95. align_corners=True,
  96. ),
  97. (0, width - w, 0, height - h),
  98. )
  99. for h, w in zip(x_h, x_w)
  100. ],
  101. dim=0,
  102. )
  103. pos_embed = pos_embed.flatten(2).transpose(1, 2)
  104. x = x.flatten(2).transpose(1, 2)
  105. # Set `device` here, otherwise `patch_index` will always be on `CPU` and will fail near the end for torch>=1.13
  106. patch_index = torch.stack(
  107. meshgrid(torch.arange(x_mask.shape[-2]), torch.arange(x_mask.shape[-1]), indexing="ij"), dim=-1
  108. ).to(device=x_mask.device)
  109. patch_index = patch_index[None, None, :, :, :]
  110. patch_index = patch_index.expand(x_mask.shape[0], x_mask.shape[1], -1, -1, -1)
  111. patch_index = patch_index.flatten(1, 3)
  112. x_mask = x_mask.flatten(1)
  113. if max_image_length < 0 or max_image_length is None or not isinstance(max_image_length, int):
  114. # suppose aug is 800 x 1333, then, maximum effective res is 800 x 1333 (if one side gets bigger, the other will be constrained and be shrunk)
  115. # (800 // self.patch_size) * (1333 // self.patch_size) is the maximum number of patches that single image can get.
  116. # if self.patch_size = 32, 25 * 41 = 1025
  117. # if res is 384 x 640, 12 * 20 = 240
  118. effective_resolution = x_h * x_w
  119. max_image_length = effective_resolution.max()
  120. else:
  121. effective_resolution = x_h * x_w
  122. max_image_length = min(effective_resolution.max(), max_image_length)
  123. valid_idx = x_mask.nonzero(as_tuple=False)
  124. non_valid_idx = (1 - x_mask).nonzero(as_tuple=False)
  125. unique_rows = valid_idx[:, 0].unique()
  126. valid_row_idx = [valid_idx[valid_idx[:, 0] == u] for u in unique_rows]
  127. non_valid_row_idx = [non_valid_idx[non_valid_idx[:, 0] == u] for u in unique_rows]
  128. valid_nums = [v.size(0) for v in valid_row_idx]
  129. non_valid_nums = [v.size(0) for v in non_valid_row_idx]
  130. pad_nums = [max_image_length - v for v in valid_nums]
  131. select = []
  132. for i, (v, nv, p) in enumerate(zip(valid_nums, non_valid_nums, pad_nums)):
  133. if p <= 0:
  134. valid_choice = torch.multinomial(torch.ones(v).float(), max_image_length)
  135. select.append(valid_row_idx[i][valid_choice])
  136. else:
  137. pad_choice = torch.multinomial(torch.ones(nv).float(), p, replacement=True)
  138. select.append(torch.cat([valid_row_idx[i], non_valid_row_idx[i][pad_choice]], dim=0))
  139. select = torch.cat(select, dim=0)
  140. x = x[select[:, 0], select[:, 1]].view(batch_size, -1, num_channels)
  141. x_mask = x_mask[select[:, 0], select[:, 1]].view(batch_size, -1)
  142. # `patch_index` should be on the same device as `select`, which is ensured at definition time.
  143. patch_index = patch_index[select[:, 0], select[:, 1]].view(batch_size, -1, 2)
  144. pos_embed = pos_embed[select[:, 0], select[:, 1]].view(batch_size, -1, num_channels)
  145. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  146. x = torch.cat((cls_tokens, x), dim=1)
  147. pos_embed = torch.cat(
  148. (self.position_embeddings[:, 0, :][:, None, :].expand(batch_size, -1, -1), pos_embed), dim=1
  149. )
  150. x = x + pos_embed
  151. x = self.dropout(x)
  152. x_mask = torch.cat([torch.ones(x_mask.shape[0], 1).to(x_mask), x_mask], dim=1)
  153. return x, x_mask, (patch_index, (height, width))
  154. def forward(
  155. self,
  156. input_ids,
  157. attention_mask,
  158. token_type_ids,
  159. pixel_values,
  160. pixel_mask,
  161. inputs_embeds,
  162. image_embeds,
  163. image_token_type_idx=1,
  164. ):
  165. # PART 1: text embeddings
  166. text_embeds = self.text_embeddings(
  167. input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
  168. )
  169. # PART 2: patch embeddings (with interpolated position encodings)
  170. if image_embeds is None:
  171. image_embeds, image_masks, patch_index = self.visual_embed(
  172. pixel_values, pixel_mask, max_image_length=self.config.max_image_length
  173. )
  174. else:
  175. image_masks = pixel_mask.flatten(1)
  176. # PART 3: add modality type embeddings
  177. # 0 indicates text, 1 indicates image, 2 is optionally used when a second image is provided (NLVR2)
  178. if image_token_type_idx is None:
  179. image_token_type_idx = 1
  180. text_embeds = text_embeds + self.token_type_embeddings(
  181. torch.zeros_like(attention_mask, dtype=torch.long, device=text_embeds.device)
  182. )
  183. image_embeds = image_embeds + self.token_type_embeddings(
  184. torch.full_like(image_masks, image_token_type_idx, dtype=torch.long, device=text_embeds.device)
  185. )
  186. # PART 4: concatenate
  187. embeddings = torch.cat([text_embeds, image_embeds], dim=1)
  188. masks = torch.cat([attention_mask, image_masks], dim=1)
  189. return embeddings, masks
  190. class TextEmbeddings(nn.Module):
  191. """Construct the embeddings from word, position and token_type embeddings."""
  192. def __init__(self, config):
  193. super().__init__()
  194. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  195. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  196. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  197. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  198. # any TensorFlow checkpoint file
  199. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  200. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  201. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  202. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  203. self.register_buffer(
  204. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  205. )
  206. self.register_buffer(
  207. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  208. )
  209. def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
  210. if input_ids is not None:
  211. input_shape = input_ids.size()
  212. else:
  213. input_shape = inputs_embeds.size()[:-1]
  214. seq_length = input_shape[1]
  215. if position_ids is None:
  216. position_ids = self.position_ids[:, :seq_length]
  217. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  218. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  219. # issue #5664
  220. if token_type_ids is None:
  221. if hasattr(self, "token_type_ids"):
  222. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  223. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  224. token_type_ids = buffered_token_type_ids_expanded
  225. else:
  226. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  227. if inputs_embeds is None:
  228. inputs_embeds = self.word_embeddings(input_ids)
  229. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  230. embeddings = inputs_embeds + token_type_embeddings
  231. if self.position_embedding_type == "absolute":
  232. position_embeddings = self.position_embeddings(position_ids)
  233. embeddings += position_embeddings
  234. embeddings = self.LayerNorm(embeddings)
  235. embeddings = self.dropout(embeddings)
  236. return embeddings
  237. class ViltPatchEmbeddings(nn.Module):
  238. """
  239. Image to Patch Embedding.
  240. """
  241. def __init__(self, config):
  242. super().__init__()
  243. image_size, patch_size = config.image_size, config.patch_size
  244. num_channels, hidden_size = config.num_channels, config.hidden_size
  245. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  246. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  247. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  248. self.image_size = image_size
  249. self.patch_size = patch_size
  250. self.num_channels = num_channels
  251. self.num_patches = num_patches
  252. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  253. def forward(self, pixel_values):
  254. batch_size, num_channels, height, width = pixel_values.shape
  255. if num_channels != self.num_channels:
  256. raise ValueError(
  257. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  258. )
  259. target_dtype = self.projection.weight.dtype
  260. x = self.projection(pixel_values.to(dtype=target_dtype))
  261. return x
  262. class ViltSelfAttention(nn.Module):
  263. def __init__(self, config):
  264. super().__init__()
  265. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  266. raise ValueError(
  267. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  268. f"heads {config.num_attention_heads}."
  269. )
  270. self.num_attention_heads = config.num_attention_heads
  271. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  272. self.all_head_size = self.num_attention_heads * self.attention_head_size
  273. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  274. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  275. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  276. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  277. def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
  278. batch_size, seq_length, _ = hidden_states.shape
  279. query_layer = (
  280. self.query(hidden_states)
  281. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  282. .transpose(1, 2)
  283. )
  284. key_layer = (
  285. self.key(hidden_states)
  286. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  287. .transpose(1, 2)
  288. )
  289. value_layer = (
  290. self.value(hidden_states)
  291. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  292. .transpose(1, 2)
  293. )
  294. # Take the dot product between "query" and "key" to get the raw attention scores.
  295. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  296. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  297. if attention_mask is not None:
  298. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  299. attention_scores = attention_scores + attention_mask
  300. # Normalize the attention scores to probabilities.
  301. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  302. # This is actually dropping out entire tokens to attend to, which might
  303. # seem a bit unusual, but is taken from the original Transformer paper.
  304. attention_probs = self.dropout(attention_probs)
  305. # Mask heads if we want to
  306. if head_mask is not None:
  307. attention_probs = attention_probs * head_mask
  308. context_layer = torch.matmul(attention_probs, value_layer)
  309. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  310. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  311. context_layer = context_layer.view(*new_context_layer_shape)
  312. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  313. return outputs
  314. # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vilt
  315. class ViltSelfOutput(nn.Module):
  316. """
  317. The residual connection is defined in ViltLayer instead of here (as is the case with other models), due to the
  318. layernorm applied before each block.
  319. """
  320. def __init__(self, config: ViltConfig):
  321. super().__init__()
  322. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  323. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  324. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  325. hidden_states = self.dense(hidden_states)
  326. hidden_states = self.dropout(hidden_states)
  327. return hidden_states
  328. class ViltAttention(nn.Module):
  329. def __init__(self, config):
  330. super().__init__()
  331. self.attention = ViltSelfAttention(config)
  332. self.output = ViltSelfOutput(config)
  333. self.pruned_heads = set()
  334. def prune_heads(self, heads):
  335. if len(heads) == 0:
  336. return
  337. heads, index = find_pruneable_heads_and_indices(
  338. heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
  339. )
  340. # Prune linear layers
  341. self.attention.query = prune_linear_layer(self.attention.query, index)
  342. self.attention.key = prune_linear_layer(self.attention.key, index)
  343. self.attention.value = prune_linear_layer(self.attention.value, index)
  344. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  345. # Update hyper params and store pruned heads
  346. self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
  347. self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
  348. self.pruned_heads = self.pruned_heads.union(heads)
  349. def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
  350. self_outputs = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
  351. attention_output = self.output(self_outputs[0], hidden_states)
  352. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  353. return outputs
  354. # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Vilt
  355. class ViltIntermediate(nn.Module):
  356. def __init__(self, config: ViltConfig):
  357. super().__init__()
  358. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  359. if isinstance(config.hidden_act, str):
  360. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  361. else:
  362. self.intermediate_act_fn = config.hidden_act
  363. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  364. hidden_states = self.dense(hidden_states)
  365. hidden_states = self.intermediate_act_fn(hidden_states)
  366. return hidden_states
  367. # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->Vilt
  368. class ViltOutput(nn.Module):
  369. def __init__(self, config: ViltConfig):
  370. super().__init__()
  371. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  372. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  373. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  374. hidden_states = self.dense(hidden_states)
  375. hidden_states = self.dropout(hidden_states)
  376. hidden_states = hidden_states + input_tensor
  377. return hidden_states
  378. class ViltLayer(GradientCheckpointingLayer):
  379. """This corresponds to the Block class in the timm implementation."""
  380. def __init__(self, config):
  381. super().__init__()
  382. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  383. self.seq_len_dim = 1
  384. self.attention = ViltAttention(config)
  385. self.intermediate = ViltIntermediate(config)
  386. self.output = ViltOutput(config)
  387. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  388. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  389. def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
  390. self_attention_outputs = self.attention(
  391. self.layernorm_before(hidden_states), # in ViLT, layernorm is applied before self-attention
  392. attention_mask,
  393. head_mask,
  394. output_attentions=output_attentions,
  395. )
  396. attention_output = self_attention_outputs[0]
  397. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  398. # first residual connection
  399. hidden_states = attention_output + hidden_states.to(attention_output.device)
  400. # in ViLT, layernorm is also applied after self-attention
  401. layer_output = self.layernorm_after(hidden_states)
  402. layer_output = self.intermediate(layer_output)
  403. # second residual connection is done here
  404. layer_output = self.output(layer_output, hidden_states)
  405. outputs = (layer_output,) + outputs
  406. return outputs
  407. class ViltEncoder(nn.Module):
  408. def __init__(self, config):
  409. super().__init__()
  410. self.config = config
  411. self.layer = nn.ModuleList([ViltLayer(config) for _ in range(config.num_hidden_layers)])
  412. self.gradient_checkpointing = False
  413. def forward(
  414. self,
  415. hidden_states,
  416. attention_mask=None,
  417. head_mask=None,
  418. output_attentions=False,
  419. output_hidden_states=False,
  420. return_dict=True,
  421. ):
  422. all_hidden_states = () if output_hidden_states else None
  423. all_self_attentions = () if output_attentions else None
  424. for i, layer_module in enumerate(self.layer):
  425. if output_hidden_states:
  426. all_hidden_states = all_hidden_states + (hidden_states,)
  427. layer_head_mask = head_mask[i] if head_mask is not None else None
  428. layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
  429. hidden_states = layer_outputs[0]
  430. if output_attentions:
  431. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  432. if output_hidden_states:
  433. all_hidden_states = all_hidden_states + (hidden_states,)
  434. if not return_dict:
  435. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  436. return BaseModelOutput(
  437. last_hidden_state=hidden_states,
  438. hidden_states=all_hidden_states,
  439. attentions=all_self_attentions,
  440. )
  441. @auto_docstring
  442. class ViltPreTrainedModel(PreTrainedModel):
  443. config: ViltConfig
  444. base_model_prefix = "vilt"
  445. supports_gradient_checkpointing = True
  446. _no_split_modules = ["ViltEmbeddings", "ViltSelfAttention"]
  447. def _init_weights(self, module):
  448. """Initialize the weights"""
  449. if isinstance(module, (nn.Linear, nn.Conv2d)):
  450. # Slightly different from the TF version which uses truncated_normal for initialization
  451. # cf https://github.com/pytorch/pytorch/pull/5617
  452. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  453. if module.bias is not None:
  454. module.bias.data.zero_()
  455. elif isinstance(module, nn.Embedding):
  456. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  457. if module.padding_idx is not None:
  458. module.weight.data[module.padding_idx].zero_()
  459. elif isinstance(module, nn.LayerNorm):
  460. module.bias.data.zero_()
  461. module.weight.data.fill_(1.0)
  462. @auto_docstring
  463. class ViltModel(ViltPreTrainedModel):
  464. def __init__(self, config, add_pooling_layer=True):
  465. r"""
  466. add_pooling_layer (bool, *optional*, defaults to `True`):
  467. Whether to add a pooling layer
  468. """
  469. super().__init__(config)
  470. self.config = config
  471. self.embeddings = ViltEmbeddings(config)
  472. self.encoder = ViltEncoder(config)
  473. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  474. self.pooler = ViltPooler(config) if add_pooling_layer else None
  475. # Initialize weights and apply final processing
  476. self.post_init()
  477. def get_input_embeddings(self):
  478. return self.embeddings.text_embeddings.word_embeddings
  479. def set_input_embeddings(self, value):
  480. self.embeddings.text_embeddings.word_embeddings = value
  481. def _prune_heads(self, heads_to_prune):
  482. """
  483. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  484. class PreTrainedModel
  485. """
  486. for layer, heads in heads_to_prune.items():
  487. self.encoder.layer[layer].attention.prune_heads(heads)
  488. @auto_docstring
  489. def forward(
  490. self,
  491. input_ids: Optional[torch.LongTensor] = None,
  492. attention_mask: Optional[torch.FloatTensor] = None,
  493. token_type_ids: Optional[torch.LongTensor] = None,
  494. pixel_values: Optional[torch.FloatTensor] = None,
  495. pixel_mask: Optional[torch.LongTensor] = None,
  496. head_mask: Optional[torch.FloatTensor] = None,
  497. inputs_embeds: Optional[torch.FloatTensor] = None,
  498. image_embeds: Optional[torch.FloatTensor] = None,
  499. image_token_type_idx: Optional[int] = None,
  500. output_attentions: Optional[bool] = None,
  501. output_hidden_states: Optional[bool] = None,
  502. return_dict: Optional[bool] = None,
  503. ) -> Union[BaseModelOutputWithPooling, tuple[torch.FloatTensor]]:
  504. r"""
  505. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  506. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  507. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  508. image_token_type_idx (`int`, *optional*):
  509. - The token type ids for images.
  510. Examples:
  511. ```python
  512. >>> from transformers import ViltProcessor, ViltModel
  513. >>> from PIL import Image
  514. >>> import requests
  515. >>> # prepare image and text
  516. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  517. >>> image = Image.open(requests.get(url, stream=True).raw)
  518. >>> text = "hello world"
  519. >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
  520. >>> model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")
  521. >>> inputs = processor(image, text, return_tensors="pt")
  522. >>> outputs = model(**inputs)
  523. >>> last_hidden_states = outputs.last_hidden_state
  524. ```"""
  525. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  526. output_hidden_states = (
  527. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  528. )
  529. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  530. if input_ids is not None and inputs_embeds is not None:
  531. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  532. elif input_ids is not None:
  533. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  534. input_shape = input_ids.size()
  535. elif inputs_embeds is not None:
  536. input_shape = inputs_embeds.size()[:-1]
  537. else:
  538. raise ValueError("You have to specify either input_ids or inputs_embeds")
  539. text_batch_size, seq_length = input_shape
  540. device = input_ids.device if input_ids is not None else inputs_embeds.device
  541. if attention_mask is None:
  542. attention_mask = torch.ones(((text_batch_size, seq_length)), device=device)
  543. if pixel_values is not None and image_embeds is not None:
  544. raise ValueError("You cannot specify both pixel_values and image_embeds at the same time")
  545. elif pixel_values is None and image_embeds is None:
  546. raise ValueError("You have to specify either pixel_values or image_embeds")
  547. image_batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeds.shape[0]
  548. if image_batch_size != text_batch_size:
  549. raise ValueError("The text inputs and image inputs need to have the same batch size")
  550. if pixel_mask is None:
  551. pixel_mask = torch.ones((image_batch_size, self.config.image_size, self.config.image_size), device=device)
  552. # Prepare head mask if needed
  553. # 1.0 in head_mask indicate we keep the head
  554. # attention_probs has shape bsz x n_heads x N x N
  555. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  556. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  557. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  558. embedding_output, attention_mask = self.embeddings(
  559. input_ids,
  560. attention_mask,
  561. token_type_ids,
  562. pixel_values,
  563. pixel_mask,
  564. inputs_embeds,
  565. image_embeds,
  566. image_token_type_idx=image_token_type_idx,
  567. )
  568. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  569. # ourselves in which case we just need to make it broadcastable to all heads.
  570. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  571. encoder_outputs = self.encoder(
  572. embedding_output,
  573. attention_mask=extended_attention_mask,
  574. head_mask=head_mask,
  575. output_attentions=output_attentions,
  576. output_hidden_states=output_hidden_states,
  577. return_dict=return_dict,
  578. )
  579. sequence_output = encoder_outputs[0]
  580. sequence_output = self.layernorm(sequence_output)
  581. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  582. if not return_dict:
  583. return (sequence_output, pooled_output) + encoder_outputs[1:]
  584. return BaseModelOutputWithPooling(
  585. last_hidden_state=sequence_output,
  586. pooler_output=pooled_output,
  587. hidden_states=encoder_outputs.hidden_states,
  588. attentions=encoder_outputs.attentions,
  589. )
  590. class ViltPooler(nn.Module):
  591. def __init__(self, config):
  592. super().__init__()
  593. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  594. self.activation = nn.Tanh()
  595. def forward(self, hidden_states):
  596. # We "pool" the model by simply taking the hidden state corresponding
  597. # to the first token.
  598. first_token_tensor = hidden_states[:, 0]
  599. pooled_output = self.dense(first_token_tensor)
  600. pooled_output = self.activation(pooled_output)
  601. return pooled_output
  602. @auto_docstring(
  603. custom_intro="""
  604. ViLT Model with a language modeling head on top as done during pretraining.
  605. """
  606. )
  607. class ViltForMaskedLM(ViltPreTrainedModel):
  608. _tied_weights_keys = ["mlm_score.decoder.weight", "mlm_score.decoder.bias"]
  609. def __init__(self, config):
  610. super().__init__(config)
  611. self.vilt = ViltModel(config)
  612. self.mlm_score = ViltMLMHead(config)
  613. # Initialize weights and apply final processing
  614. self.post_init()
  615. def get_output_embeddings(self):
  616. return self.mlm_score.decoder
  617. def set_output_embeddings(self, new_embeddings):
  618. self.mlm_score.decoder = new_embeddings
  619. self.mlm_score.bias = new_embeddings.bias
  620. @auto_docstring
  621. def forward(
  622. self,
  623. input_ids: Optional[torch.LongTensor] = None,
  624. attention_mask: Optional[torch.FloatTensor] = None,
  625. token_type_ids: Optional[torch.LongTensor] = None,
  626. pixel_values: Optional[torch.FloatTensor] = None,
  627. pixel_mask: Optional[torch.LongTensor] = None,
  628. head_mask: Optional[torch.FloatTensor] = None,
  629. inputs_embeds: Optional[torch.FloatTensor] = None,
  630. image_embeds: Optional[torch.FloatTensor] = None,
  631. labels: Optional[torch.LongTensor] = None,
  632. output_attentions: Optional[bool] = None,
  633. output_hidden_states: Optional[bool] = None,
  634. return_dict: Optional[bool] = None,
  635. ) -> Union[MaskedLMOutput, tuple[torch.FloatTensor]]:
  636. r"""
  637. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  638. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  639. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  640. labels (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
  641. Labels for computing the masked language modeling loss. Indices should be in *[-100, 0, ...,
  642. config.vocab_size]* (see *input_ids* docstring) Tokens with indices set to *-100* are ignored (masked), the
  643. loss is only computed for the tokens with labels in *[0, ..., config.vocab_size]*
  644. Examples:
  645. ```python
  646. >>> from transformers import ViltProcessor, ViltForMaskedLM
  647. >>> import requests
  648. >>> from PIL import Image
  649. >>> import re
  650. >>> import torch
  651. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  652. >>> image = Image.open(requests.get(url, stream=True).raw)
  653. >>> text = "a bunch of [MASK] laying on a [MASK]."
  654. >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
  655. >>> model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")
  656. >>> # prepare inputs
  657. >>> encoding = processor(image, text, return_tensors="pt")
  658. >>> # forward pass
  659. >>> outputs = model(**encoding)
  660. >>> tl = len(re.findall("\[MASK\]", text))
  661. >>> inferred_token = [text]
  662. >>> # gradually fill in the MASK tokens, one by one
  663. >>> with torch.no_grad():
  664. ... for i in range(tl):
  665. ... encoded = processor.tokenizer(inferred_token)
  666. ... input_ids = torch.tensor(encoded.input_ids)
  667. ... encoded = encoded["input_ids"][0][1:-1]
  668. ... outputs = model(input_ids=input_ids, pixel_values=encoding.pixel_values)
  669. ... mlm_logits = outputs.logits[0] # shape (seq_len, vocab_size)
  670. ... # only take into account text features (minus CLS and SEP token)
  671. ... mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
  672. ... mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
  673. ... # only take into account text
  674. ... mlm_values[torch.tensor(encoded) != 103] = 0
  675. ... select = mlm_values.argmax().item()
  676. ... encoded[select] = mlm_ids[select].item()
  677. ... inferred_token = [processor.decode(encoded)]
  678. >>> selected_token = ""
  679. >>> encoded = processor.tokenizer(inferred_token)
  680. >>> output = processor.decode(encoded.input_ids[0], skip_special_tokens=True)
  681. >>> print(output)
  682. a bunch of cats laying on a couch.
  683. ```"""
  684. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  685. outputs = self.vilt(
  686. input_ids,
  687. attention_mask=attention_mask,
  688. token_type_ids=token_type_ids,
  689. pixel_values=pixel_values,
  690. pixel_mask=pixel_mask,
  691. head_mask=head_mask,
  692. inputs_embeds=inputs_embeds,
  693. image_embeds=image_embeds,
  694. output_attentions=output_attentions,
  695. output_hidden_states=output_hidden_states,
  696. return_dict=return_dict,
  697. )
  698. sequence_output, pooled_output = outputs[:2]
  699. # split up final hidden states into text and image features
  700. text_seq_len = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  701. text_features, _ = (sequence_output[:, :text_seq_len], sequence_output[:, text_seq_len:])
  702. mlm_logits = self.mlm_score(text_features)
  703. masked_lm_loss = None
  704. if labels is not None:
  705. loss_fct = CrossEntropyLoss() # -100 index = padding token
  706. # move labels to correct device to enable PP
  707. labels = labels.to(mlm_logits.device)
  708. masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  709. if not return_dict:
  710. output = (mlm_logits,) + outputs[2:]
  711. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  712. return MaskedLMOutput(
  713. loss=masked_lm_loss,
  714. logits=mlm_logits,
  715. hidden_states=outputs.hidden_states,
  716. attentions=outputs.attentions,
  717. )
  718. class ViltPredictionHeadTransform(nn.Module):
  719. def __init__(self, config):
  720. super().__init__()
  721. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  722. if isinstance(config.hidden_act, str):
  723. self.transform_act_fn = ACT2FN[config.hidden_act]
  724. else:
  725. self.transform_act_fn = config.hidden_act
  726. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  727. def forward(self, hidden_states):
  728. hidden_states = self.dense(hidden_states)
  729. hidden_states = self.transform_act_fn(hidden_states)
  730. hidden_states = self.LayerNorm(hidden_states)
  731. return hidden_states
  732. class ViltMLMHead(nn.Module):
  733. def __init__(self, config, weight=None):
  734. super().__init__()
  735. self.config = config
  736. self.transform = ViltPredictionHeadTransform(config)
  737. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  738. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  739. if weight is not None:
  740. self.decoder.weight = weight
  741. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  742. self.decoder.bias = self.bias
  743. def _tie_weights(self):
  744. self.decoder.bias = self.bias
  745. def forward(self, x):
  746. x = self.transform(x)
  747. x = self.decoder(x)
  748. return x
  749. @auto_docstring(
  750. custom_intro="""
  751. Vilt Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the [CLS]
  752. token) for visual question answering, e.g. for VQAv2.
  753. """
  754. )
  755. class ViltForQuestionAnswering(ViltPreTrainedModel):
  756. def __init__(self, config):
  757. super().__init__(config)
  758. self.num_labels = config.num_labels
  759. self.vilt = ViltModel(config)
  760. # Classifier head
  761. self.classifier = nn.Sequential(
  762. nn.Linear(config.hidden_size, config.hidden_size * 2),
  763. nn.LayerNorm(config.hidden_size * 2),
  764. nn.GELU(),
  765. nn.Linear(config.hidden_size * 2, config.num_labels),
  766. )
  767. # Initialize weights and apply final processing
  768. self.post_init()
  769. @auto_docstring
  770. def forward(
  771. self,
  772. input_ids: Optional[torch.LongTensor] = None,
  773. attention_mask: Optional[torch.FloatTensor] = None,
  774. token_type_ids: Optional[torch.LongTensor] = None,
  775. pixel_values: Optional[torch.FloatTensor] = None,
  776. pixel_mask: Optional[torch.LongTensor] = None,
  777. head_mask: Optional[torch.FloatTensor] = None,
  778. inputs_embeds: Optional[torch.FloatTensor] = None,
  779. image_embeds: Optional[torch.FloatTensor] = None,
  780. labels: Optional[torch.LongTensor] = None,
  781. output_attentions: Optional[bool] = None,
  782. output_hidden_states: Optional[bool] = None,
  783. return_dict: Optional[bool] = None,
  784. ) -> Union[SequenceClassifierOutput, tuple[torch.FloatTensor]]:
  785. r"""
  786. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  787. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  788. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  789. labels (`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*):
  790. Labels for computing the visual question answering loss. This tensor must be either a one-hot encoding of
  791. all answers that are applicable for a given example in the batch, or a soft encoding indicating which
  792. answers are applicable, where 1.0 is the highest score.
  793. Examples:
  794. ```python
  795. >>> from transformers import ViltProcessor, ViltForQuestionAnswering
  796. >>> import requests
  797. >>> from PIL import Image
  798. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  799. >>> image = Image.open(requests.get(url, stream=True).raw)
  800. >>> text = "How many cats are there?"
  801. >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
  802. >>> model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
  803. >>> # prepare inputs
  804. >>> encoding = processor(image, text, return_tensors="pt")
  805. >>> # forward pass
  806. >>> outputs = model(**encoding)
  807. >>> logits = outputs.logits
  808. >>> idx = logits.argmax(-1).item()
  809. >>> print("Predicted answer:", model.config.id2label[idx])
  810. Predicted answer: 2
  811. ```"""
  812. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  813. outputs = self.vilt(
  814. input_ids,
  815. attention_mask=attention_mask,
  816. token_type_ids=token_type_ids,
  817. pixel_values=pixel_values,
  818. pixel_mask=pixel_mask,
  819. head_mask=head_mask,
  820. inputs_embeds=inputs_embeds,
  821. image_embeds=image_embeds,
  822. output_attentions=output_attentions,
  823. output_hidden_states=output_hidden_states,
  824. return_dict=return_dict,
  825. )
  826. pooler_output = outputs.pooler_output if return_dict else outputs[1]
  827. logits = self.classifier(pooler_output)
  828. loss = None
  829. if labels is not None:
  830. # move labels to correct device to enable PP
  831. labels = labels.to(logits.device)
  832. loss = nn.functional.binary_cross_entropy_with_logits(logits, labels) * labels.shape[1]
  833. # see https://github.com/jnhwkim/ban-vqa/blob/master/train.py#L19
  834. if not return_dict:
  835. output = (logits,) + outputs[2:]
  836. return ((loss,) + output) if loss is not None else output
  837. return SequenceClassifierOutput(
  838. loss=loss,
  839. logits=logits,
  840. hidden_states=outputs.hidden_states,
  841. attentions=outputs.attentions,
  842. )
  843. @auto_docstring(
  844. custom_intro="""
  845. Vilt Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the [CLS]
  846. token) for image-to-text or text-to-image retrieval, e.g. MSCOCO and F30K.
  847. """
  848. )
  849. class ViltForImageAndTextRetrieval(ViltPreTrainedModel):
  850. def __init__(self, config):
  851. super().__init__(config)
  852. self.vilt = ViltModel(config)
  853. # Classifier head
  854. self.rank_output = nn.Linear(config.hidden_size, 1)
  855. # Initialize weights and apply final processing
  856. self.post_init()
  857. @auto_docstring
  858. def forward(
  859. self,
  860. input_ids: Optional[torch.LongTensor] = None,
  861. attention_mask: Optional[torch.FloatTensor] = None,
  862. token_type_ids: Optional[torch.LongTensor] = None,
  863. pixel_values: Optional[torch.FloatTensor] = None,
  864. pixel_mask: Optional[torch.LongTensor] = None,
  865. head_mask: Optional[torch.FloatTensor] = None,
  866. inputs_embeds: Optional[torch.FloatTensor] = None,
  867. image_embeds: Optional[torch.FloatTensor] = None,
  868. labels: Optional[torch.LongTensor] = None,
  869. output_attentions: Optional[bool] = None,
  870. output_hidden_states: Optional[bool] = None,
  871. return_dict: Optional[bool] = None,
  872. ) -> Union[SequenceClassifierOutput, tuple[torch.FloatTensor]]:
  873. r"""
  874. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  875. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  876. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  877. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  878. Labels are currently not supported.
  879. Examples:
  880. ```python
  881. >>> from transformers import ViltProcessor, ViltForImageAndTextRetrieval
  882. >>> import requests
  883. >>> from PIL import Image
  884. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  885. >>> image = Image.open(requests.get(url, stream=True).raw)
  886. >>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]
  887. >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-coco")
  888. >>> model = ViltForImageAndTextRetrieval.from_pretrained("dandelin/vilt-b32-finetuned-coco")
  889. >>> # forward pass
  890. >>> scores = dict()
  891. >>> for text in texts:
  892. ... # prepare inputs
  893. ... encoding = processor(image, text, return_tensors="pt")
  894. ... outputs = model(**encoding)
  895. ... scores[text] = outputs.logits[0, :].item()
  896. ```"""
  897. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  898. loss = None
  899. if labels is not None:
  900. raise NotImplementedError("Training is not yet supported.")
  901. outputs = self.vilt(
  902. input_ids,
  903. attention_mask=attention_mask,
  904. token_type_ids=token_type_ids,
  905. pixel_values=pixel_values,
  906. pixel_mask=pixel_mask,
  907. head_mask=head_mask,
  908. inputs_embeds=inputs_embeds,
  909. image_embeds=image_embeds,
  910. output_attentions=output_attentions,
  911. output_hidden_states=output_hidden_states,
  912. return_dict=return_dict,
  913. )
  914. pooler_output = outputs.pooler_output if return_dict else outputs[1]
  915. logits = self.rank_output(pooler_output)
  916. if not return_dict:
  917. output = (logits,) + outputs[2:]
  918. return ((loss,) + output) if loss is not None else output
  919. return SequenceClassifierOutput(
  920. loss=loss,
  921. logits=logits,
  922. hidden_states=outputs.hidden_states,
  923. attentions=outputs.attentions,
  924. )
  925. @auto_docstring(
  926. custom_intro="""
  927. Vilt Model transformer with a classifier head on top for natural language visual reasoning, e.g. NLVR2.
  928. """
  929. )
  930. class ViltForImagesAndTextClassification(ViltPreTrainedModel):
  931. def __init__(self, config):
  932. super().__init__(config)
  933. self.num_labels = config.num_labels
  934. self.vilt = ViltModel(config)
  935. # Classifier head
  936. num_images = config.num_images
  937. self.classifier = nn.Sequential(
  938. nn.Linear(config.hidden_size * num_images, config.hidden_size * num_images),
  939. nn.LayerNorm(config.hidden_size * num_images),
  940. nn.GELU(),
  941. nn.Linear(config.hidden_size * num_images, config.num_labels),
  942. )
  943. # Initialize weights and apply final processing
  944. self.post_init()
  945. @auto_docstring
  946. def forward(
  947. self,
  948. input_ids: Optional[torch.LongTensor] = None,
  949. attention_mask: Optional[torch.FloatTensor] = None,
  950. token_type_ids: Optional[torch.LongTensor] = None,
  951. pixel_values: Optional[torch.FloatTensor] = None,
  952. pixel_mask: Optional[torch.LongTensor] = None,
  953. head_mask: Optional[torch.FloatTensor] = None,
  954. inputs_embeds: Optional[torch.FloatTensor] = None,
  955. image_embeds: Optional[torch.FloatTensor] = None,
  956. labels: Optional[torch.LongTensor] = None,
  957. output_attentions: Optional[bool] = None,
  958. output_hidden_states: Optional[bool] = None,
  959. return_dict: Optional[bool] = None,
  960. ) -> Union[ViltForImagesAndTextClassificationOutput, tuple[torch.FloatTensor]]:
  961. r"""
  962. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  963. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  964. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  965. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  966. Binary classification labels.
  967. Examples:
  968. ```python
  969. >>> from transformers import ViltProcessor, ViltForImagesAndTextClassification
  970. >>> import requests
  971. >>> from PIL import Image
  972. >>> image1 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_0.jpg", stream=True).raw)
  973. >>> image2 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_1.jpg", stream=True).raw)
  974. >>> text = "The left image contains twice the number of dogs as the right image."
  975. >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")
  976. >>> model = ViltForImagesAndTextClassification.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")
  977. >>> # prepare inputs
  978. >>> encoding = processor([image1, image2], text, return_tensors="pt")
  979. >>> # forward pass
  980. >>> outputs = model(input_ids=encoding.input_ids, pixel_values=encoding.pixel_values.unsqueeze(0))
  981. >>> logits = outputs.logits
  982. >>> idx = logits.argmax(-1).item()
  983. >>> print("Predicted answer:", model.config.id2label[idx])
  984. Predicted answer: True
  985. ```"""
  986. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  987. output_hidden_states = (
  988. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  989. )
  990. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  991. if pixel_values is not None and pixel_values.ndim == 4:
  992. # add dummy num_images dimension
  993. pixel_values = pixel_values.unsqueeze(1)
  994. if image_embeds is not None and image_embeds.ndim == 3:
  995. # add dummy num_images dimension
  996. image_embeds = image_embeds.unsqueeze(1)
  997. num_images = pixel_values.shape[1] if pixel_values is not None else None
  998. if num_images is None:
  999. num_images = image_embeds.shape[1] if image_embeds is not None else None
  1000. if num_images != self.config.num_images:
  1001. raise ValueError(
  1002. "Make sure to match the number of images in the model with the number of images in the input."
  1003. )
  1004. pooler_outputs = []
  1005. hidden_states = [] if output_hidden_states else None
  1006. attentions = [] if output_attentions else None
  1007. for i in range(num_images):
  1008. # forward every image through the model
  1009. outputs = self.vilt(
  1010. input_ids,
  1011. attention_mask=attention_mask,
  1012. token_type_ids=token_type_ids,
  1013. pixel_values=pixel_values[:, i, :, :, :] if pixel_values is not None else None,
  1014. pixel_mask=pixel_mask[:, i, :, :] if pixel_mask is not None else None,
  1015. head_mask=head_mask,
  1016. inputs_embeds=inputs_embeds,
  1017. image_embeds=image_embeds[:, i, :, :] if image_embeds is not None else None,
  1018. image_token_type_idx=i + 1,
  1019. output_attentions=output_attentions,
  1020. output_hidden_states=output_hidden_states,
  1021. return_dict=return_dict,
  1022. )
  1023. pooler_output = outputs.pooler_output if return_dict else outputs[1]
  1024. pooler_outputs.append(pooler_output)
  1025. if output_hidden_states:
  1026. hidden_states.append(outputs.hidden_states)
  1027. if output_attentions:
  1028. attentions.append(outputs.attentions)
  1029. pooled_output = torch.cat(pooler_outputs, dim=-1)
  1030. logits = self.classifier(pooled_output)
  1031. loss = None
  1032. if labels is not None:
  1033. loss_fct = CrossEntropyLoss()
  1034. # move labels to correct device to enable PP
  1035. labels = labels.to(logits.device)
  1036. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1037. if not return_dict:
  1038. output = (logits, hidden_states, attentions)
  1039. return ((loss,) + output) if loss is not None else output
  1040. return ViltForImagesAndTextClassificationOutput(
  1041. loss=loss,
  1042. logits=logits,
  1043. hidden_states=hidden_states,
  1044. attentions=attentions,
  1045. )
  1046. @auto_docstring
  1047. class ViltForTokenClassification(ViltPreTrainedModel):
  1048. def __init__(self, config):
  1049. super().__init__(config)
  1050. self.num_labels = config.num_labels
  1051. self.vilt = ViltModel(config, add_pooling_layer=False)
  1052. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  1053. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1054. # Initialize weights and apply final processing
  1055. self.post_init()
  1056. @auto_docstring
  1057. def forward(
  1058. self,
  1059. input_ids: Optional[torch.LongTensor] = None,
  1060. attention_mask: Optional[torch.FloatTensor] = None,
  1061. token_type_ids: Optional[torch.LongTensor] = None,
  1062. pixel_values: Optional[torch.FloatTensor] = None,
  1063. pixel_mask: Optional[torch.LongTensor] = None,
  1064. head_mask: Optional[torch.FloatTensor] = None,
  1065. inputs_embeds: Optional[torch.FloatTensor] = None,
  1066. image_embeds: Optional[torch.FloatTensor] = None,
  1067. labels: Optional[torch.LongTensor] = None,
  1068. output_attentions: Optional[bool] = None,
  1069. output_hidden_states: Optional[bool] = None,
  1070. return_dict: Optional[bool] = None,
  1071. ) -> Union[TokenClassifierOutput, tuple[torch.FloatTensor]]:
  1072. r"""
  1073. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  1074. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  1075. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  1076. labels (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*):
  1077. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1078. """
  1079. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1080. outputs = self.vilt(
  1081. input_ids,
  1082. attention_mask=attention_mask,
  1083. token_type_ids=token_type_ids,
  1084. pixel_values=pixel_values,
  1085. pixel_mask=pixel_mask,
  1086. head_mask=head_mask,
  1087. inputs_embeds=inputs_embeds,
  1088. image_embeds=image_embeds,
  1089. output_attentions=output_attentions,
  1090. output_hidden_states=output_hidden_states,
  1091. return_dict=return_dict,
  1092. )
  1093. sequence_output = outputs[0]
  1094. text_input_size = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1095. sequence_output = self.dropout(sequence_output)
  1096. logits = self.classifier(sequence_output[:, :text_input_size])
  1097. loss = None
  1098. if labels is not None:
  1099. loss_fct = CrossEntropyLoss()
  1100. # move labels to correct device to enable PP
  1101. labels = labels.to(logits.device)
  1102. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1103. if not return_dict:
  1104. output = (logits,) + outputs[2:]
  1105. return ((loss,) + output) if loss is not None else output
  1106. return TokenClassifierOutput(
  1107. loss=loss,
  1108. logits=logits,
  1109. hidden_states=outputs.hidden_states,
  1110. attentions=outputs.attentions,
  1111. )
  1112. __all__ = [
  1113. "ViltForImageAndTextRetrieval",
  1114. "ViltForImagesAndTextClassification",
  1115. "ViltForTokenClassification",
  1116. "ViltForMaskedLM",
  1117. "ViltForQuestionAnswering",
  1118. "ViltLayer",
  1119. "ViltModel",
  1120. "ViltPreTrainedModel",
  1121. ]