modeling_lxmert.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415
  1. # coding=utf-8
  2. # Copyright 2018 Hao Tan, Mohit Bansal, and the HuggingFace team
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch LXMERT model."""
  16. import math
  17. import os
  18. import warnings
  19. from dataclasses import dataclass
  20. from typing import Optional, Union
  21. import torch
  22. from torch import nn
  23. from torch.nn import CrossEntropyLoss, SmoothL1Loss
  24. from ...activations import ACT2FN, gelu
  25. from ...modeling_utils import PreTrainedModel
  26. from ...utils import ModelOutput, auto_docstring, logging
  27. from .configuration_lxmert import LxmertConfig
  28. logger = logging.get_logger(__name__)
  29. class GeLU(nn.Module):
  30. def __init__(self):
  31. super().__init__()
  32. def forward(self, x):
  33. return gelu(x)
@dataclass
@auto_docstring(
    custom_intro="""
    Lxmert's outputs that contain the last hidden states, pooled outputs, and attention probabilities for the language,
    visual, and, cross-modality encoders. (note: the visual encoder in Lxmert is referred to as the "relation-ship"
    encoder")
    """
)
class LxmertModelOutput(ModelOutput):
    r"""
    language_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the language encoder.
    vision_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the visual encoder.
    pooled_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Last layer hidden-state of the first token of the sequence (classification, CLS, token) further processed
        by a Linear layer and a Tanh activation function.
    language_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
        shape `(batch_size, sequence_length, hidden_size)`.
    vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
        shape `(batch_size, sequence_length, hidden_size)`.
    language_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    cross_encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    """

    # NOTE: field order is part of the interface (it is the declaration order of this dataclass); do not reorder.
    language_output: Optional[torch.FloatTensor] = None
    vision_output: Optional[torch.FloatTensor] = None
    pooled_output: Optional[torch.FloatTensor] = None
    language_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    vision_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    language_attentions: Optional[tuple[torch.FloatTensor]] = None
    vision_attentions: Optional[tuple[torch.FloatTensor]] = None
    cross_encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`LxmertForQuestionAnswering`].
    """
)
class LxmertForQuestionAnsweringOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    question_answering_score (`torch.FloatTensor` of shape `(batch_size, n_qa_answers)`, *optional*):
        Prediction scores of question answering objective (classification).
    language_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
        shape `(batch_size, sequence_length, hidden_size)`.
    vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
        shape `(batch_size, sequence_length, hidden_size)`.
    language_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    cross_encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    """

    # NOTE(review): the `loss` description mentions masked-LM / next-sequence losses, which looks copied from a
    # pre-training output; confirm against the question-answering head before relying on it.
    loss: Optional[torch.FloatTensor] = None
    question_answering_score: Optional[torch.FloatTensor] = None
    language_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    vision_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    language_attentions: Optional[tuple[torch.FloatTensor]] = None
    vision_attentions: Optional[tuple[torch.FloatTensor]] = None
    cross_encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`LxmertForPreTraining`].
    """
)
class LxmertForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    cross_relationship_score (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the textual matching objective (classification) head (scores of True/False
        continuation before SoftMax).
    question_answering_score (`torch.FloatTensor` of shape `(batch_size, n_qa_answers)`):
        Prediction scores of question answering objective (classification).
    language_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
        shape `(batch_size, sequence_length, hidden_size)`.
    vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
        shape `(batch_size, sequence_length, hidden_size)`.
    language_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    cross_encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    """

    # NOTE: field order is part of the interface (it is the declaration order of this dataclass); do not reorder.
    loss: Optional[torch.FloatTensor] = None
    prediction_logits: Optional[torch.FloatTensor] = None
    cross_relationship_score: Optional[torch.FloatTensor] = None
    question_answering_score: Optional[torch.FloatTensor] = None
    language_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    vision_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    language_attentions: Optional[tuple[torch.FloatTensor]] = None
    vision_attentions: Optional[tuple[torch.FloatTensor]] = None
    cross_encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
def load_tf_weights_in_lxmert(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model.

    Every variable stored in the TensorFlow checkpoint is read, its `/`-separated name is walked
    attribute-by-attribute starting from `model`, and the resolved parameter's `.data` is replaced by the
    checkpoint array.

    Args:
        model: Object whose (nested) attributes receive the weights; it is modified in place.
        config: Not referenced anywhere in this function body.
        tf_checkpoint_path: Path to the TensorFlow checkpoint; made absolute with `os.path.abspath`.

    Returns:
        The same `model` object that was passed in.

    Raises:
        ImportError: Re-raised (after logging an error) when `re`, `numpy` or `tensorflow` cannot be imported.
        AssertionError: When a resolved parameter and its checkpoint array differ in shape; both shapes are
            appended to the exception's `args`.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        # From here on `name` is the list of scope components, not the original string.
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(
            n
            in [
                "adam_v",
                "adam_m",
                "AdamWeightDecayOptimizer",
                "AdamWeightDecayOptimizer_1",
                "global_step",
            ]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            # Components such as "layer_3" are split into an attribute name and a numeric index.
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            # Map TensorFlow variable names onto the PyTorch attribute names.
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    # NOTE(review): this `continue` appears to advance only the inner loop over name
                    # components, so the variable is not actually skipped despite the log message;
                    # confirm the intended nesting before changing it.
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        # `m_name` still holds the last component of the variable name here.
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            # Kernels are transposed before being copied into the `weight` attribute.
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            # Attach both shapes so the failure message shows the mismatch.
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
  238. class LxmertEmbeddings(nn.Module):
  239. """Construct the embeddings from word, position and token_type embeddings."""
  240. def __init__(self, config):
  241. super().__init__()
  242. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
  243. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)
  244. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0)
  245. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  246. # any TensorFlow checkpoint file
  247. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
  248. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  249. def forward(self, input_ids, token_type_ids=None, inputs_embeds=None):
  250. if input_ids is not None:
  251. input_shape = input_ids.size()
  252. device = input_ids.device
  253. else:
  254. input_shape = inputs_embeds.size()[:-1]
  255. device = inputs_embeds.device
  256. seq_length = input_shape[1]
  257. position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
  258. position_ids = position_ids.unsqueeze(0).expand(input_shape)
  259. if token_type_ids is None:
  260. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  261. if inputs_embeds is None:
  262. inputs_embeds = self.word_embeddings(input_ids)
  263. position_embeddings = self.position_embeddings(position_ids)
  264. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  265. embeddings = inputs_embeds + position_embeddings + token_type_embeddings
  266. embeddings = self.LayerNorm(embeddings)
  267. embeddings = self.dropout(embeddings)
  268. return embeddings
  269. class LxmertAttention(nn.Module):
  270. def __init__(self, config, ctx_dim=None):
  271. super().__init__()
  272. if config.hidden_size % config.num_attention_heads != 0:
  273. raise ValueError(
  274. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  275. f"heads ({config.num_attention_heads})"
  276. )
  277. self.num_attention_heads = config.num_attention_heads
  278. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  279. self.head_size = self.num_attention_heads * self.attention_head_size
  280. # visual_dim = 2048
  281. if ctx_dim is None:
  282. ctx_dim = config.hidden_size
  283. self.query = nn.Linear(config.hidden_size, self.head_size)
  284. self.key = nn.Linear(ctx_dim, self.head_size)
  285. self.value = nn.Linear(ctx_dim, self.head_size)
  286. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  287. def forward(self, hidden_states, context, attention_mask=None, output_attentions=False):
  288. batch_size, seq_length, _ = hidden_states.shape
  289. query_layer = (
  290. self.query(hidden_states)
  291. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  292. .transpose(1, 2)
  293. )
  294. key_layer = (
  295. self.key(context).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  296. )
  297. value_layer = (
  298. self.value(context)
  299. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  300. .transpose(1, 2)
  301. )
  302. # Take the dot product between "query" and "key" to get the raw attention scores.
  303. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  304. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  305. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  306. if attention_mask is not None:
  307. attention_scores = attention_scores + attention_mask
  308. # Normalize the attention scores to probabilities.
  309. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  310. # This is actually dropping out entire tokens to attend to, which might
  311. # seem a bit unusual, but is taken from the original Transformer paper.
  312. attention_probs = self.dropout(attention_probs)
  313. context_layer = torch.matmul(attention_probs, value_layer)
  314. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  315. new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,)
  316. context_layer = context_layer.view(new_context_layer_shape)
  317. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  318. return outputs
  319. class LxmertAttentionOutput(nn.Module):
  320. def __init__(self, config):
  321. super().__init__()
  322. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  323. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
  324. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  325. def forward(self, hidden_states, input_tensor):
  326. hidden_states = self.dense(hidden_states)
  327. hidden_states = self.dropout(hidden_states)
  328. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  329. return hidden_states
  330. class LxmertCrossAttentionLayer(nn.Module):
  331. def __init__(self, config):
  332. super().__init__()
  333. self.att = LxmertAttention(config)
  334. self.output = LxmertAttentionOutput(config)
  335. def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None, output_attentions=False):
  336. output = self.att(input_tensor, ctx_tensor, ctx_att_mask, output_attentions=output_attentions)
  337. if output_attentions:
  338. attention_probs = output[1]
  339. attention_output = self.output(output[0], input_tensor)
  340. outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
  341. return outputs
  342. class LxmertSelfAttentionLayer(nn.Module):
  343. def __init__(self, config):
  344. super().__init__()
  345. self.self = LxmertAttention(config)
  346. self.output = LxmertAttentionOutput(config)
  347. def forward(self, input_tensor, attention_mask, output_attentions=False):
  348. # Self attention attends to itself, thus keys and queries are the same (input_tensor).
  349. output = self.self(
  350. input_tensor,
  351. input_tensor,
  352. attention_mask,
  353. output_attentions=output_attentions,
  354. )
  355. if output_attentions:
  356. attention_probs = output[1]
  357. attention_output = self.output(output[0], input_tensor)
  358. outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
  359. return outputs
  360. class LxmertIntermediate(nn.Module):
  361. def __init__(self, config):
  362. super().__init__()
  363. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  364. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  365. def forward(self, hidden_states):
  366. hidden_states = self.dense(hidden_states)
  367. hidden_states = self.intermediate_act_fn(hidden_states)
  368. return hidden_states
  369. class LxmertOutput(nn.Module):
  370. def __init__(self, config):
  371. super().__init__()
  372. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  373. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
  374. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  375. def forward(self, hidden_states, input_tensor):
  376. hidden_states = self.dense(hidden_states)
  377. hidden_states = self.dropout(hidden_states)
  378. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  379. return hidden_states
  380. class LxmertLayer(nn.Module):
  381. def __init__(self, config):
  382. super().__init__()
  383. self.attention = LxmertSelfAttentionLayer(config)
  384. self.intermediate = LxmertIntermediate(config)
  385. self.output = LxmertOutput(config)
  386. def forward(self, hidden_states, attention_mask=None, output_attentions=False):
  387. outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions)
  388. attention_output = outputs[0]
  389. intermediate_output = self.intermediate(attention_output)
  390. layer_output = self.output(intermediate_output, attention_output)
  391. outputs = (layer_output,) + outputs[1:] # add attentions if we output them
  392. return outputs
  393. class LxmertXLayer(nn.Module):
  394. def __init__(self, config):
  395. super().__init__()
  396. # The cross-attention Layer
  397. self.visual_attention = LxmertCrossAttentionLayer(config)
  398. # Self-attention Layers
  399. self.lang_self_att = LxmertSelfAttentionLayer(config)
  400. self.visn_self_att = LxmertSelfAttentionLayer(config)
  401. # Intermediate and Output Layers (FFNs)
  402. self.lang_inter = LxmertIntermediate(config)
  403. self.lang_output = LxmertOutput(config)
  404. self.visn_inter = LxmertIntermediate(config)
  405. self.visn_output = LxmertOutput(config)
  406. def cross_att(
  407. self,
  408. lang_input,
  409. lang_attention_mask,
  410. visual_input,
  411. visual_attention_mask,
  412. output_x_attentions=False,
  413. ):
  414. # Cross Attention
  415. lang_att_output = self.visual_attention(
  416. lang_input,
  417. visual_input,
  418. ctx_att_mask=visual_attention_mask,
  419. output_attentions=output_x_attentions,
  420. )
  421. visual_att_output = self.visual_attention(
  422. visual_input,
  423. lang_input,
  424. ctx_att_mask=lang_attention_mask,
  425. output_attentions=False,
  426. )
  427. return lang_att_output, visual_att_output
  428. def self_att(self, lang_input, lang_attention_mask, visual_input, visual_attention_mask):
  429. # Self Attention
  430. lang_att_output = self.lang_self_att(lang_input, lang_attention_mask, output_attentions=False)
  431. visual_att_output = self.visn_self_att(visual_input, visual_attention_mask, output_attentions=False)
  432. return lang_att_output[0], visual_att_output[0]
  433. def output_fc(self, lang_input, visual_input):
  434. # FC layers
  435. lang_inter_output = self.lang_inter(lang_input)
  436. visual_inter_output = self.visn_inter(visual_input)
  437. # Layer output
  438. lang_output = self.lang_output(lang_inter_output, lang_input)
  439. visual_output = self.visn_output(visual_inter_output, visual_input)
  440. return lang_output, visual_output
  441. def forward(
  442. self,
  443. lang_feats,
  444. lang_attention_mask,
  445. visual_feats,
  446. visual_attention_mask,
  447. output_attentions=False,
  448. ):
  449. lang_att_output, visual_att_output = self.cross_att(
  450. lang_input=lang_feats,
  451. lang_attention_mask=lang_attention_mask,
  452. visual_input=visual_feats,
  453. visual_attention_mask=visual_attention_mask,
  454. output_x_attentions=output_attentions,
  455. )
  456. attention_probs = lang_att_output[1:]
  457. lang_att_output, visual_att_output = self.self_att(
  458. lang_att_output[0],
  459. lang_attention_mask,
  460. visual_att_output[0],
  461. visual_attention_mask,
  462. )
  463. lang_output, visual_output = self.output_fc(lang_att_output, visual_att_output)
  464. return (
  465. (
  466. lang_output,
  467. visual_output,
  468. attention_probs[0],
  469. )
  470. if output_attentions
  471. else (lang_output, visual_output)
  472. )
  473. class LxmertVisualFeatureEncoder(nn.Module):
  474. def __init__(self, config):
  475. super().__init__()
  476. feat_dim = config.visual_feat_dim
  477. pos_dim = config.visual_pos_dim
  478. # Object feature encoding
  479. self.visn_fc = nn.Linear(feat_dim, config.hidden_size)
  480. self.visn_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
  481. # Box position encoding
  482. self.box_fc = nn.Linear(pos_dim, config.hidden_size)
  483. self.box_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
  484. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  485. def forward(self, visual_feats, visual_pos):
  486. x = self.visn_fc(visual_feats)
  487. x = self.visn_layer_norm(x)
  488. y = self.box_fc(visual_pos)
  489. y = self.box_layer_norm(y)
  490. output = (x + y) / 2
  491. output = self.dropout(output)
  492. return output
  493. class LxmertEncoder(nn.Module):
  494. def __init__(self, config):
  495. super().__init__()
  496. # Obj-level image embedding layer
  497. self.visn_fc = LxmertVisualFeatureEncoder(config)
  498. self.config = config
  499. # Number of layers
  500. self.num_l_layers = config.l_layers
  501. self.num_x_layers = config.x_layers
  502. self.num_r_layers = config.r_layers
  503. # Layers
  504. # Using self.layer instead of self.l_layer to support loading BERT weights.
  505. self.layer = nn.ModuleList([LxmertLayer(config) for _ in range(self.num_l_layers)])
  506. self.x_layers = nn.ModuleList([LxmertXLayer(config) for _ in range(self.num_x_layers)])
  507. self.r_layers = nn.ModuleList([LxmertLayer(config) for _ in range(self.num_r_layers)])
  508. def forward(
  509. self,
  510. lang_feats,
  511. lang_attention_mask,
  512. visual_feats,
  513. visual_pos,
  514. visual_attention_mask=None,
  515. output_attentions=None,
  516. ):
  517. vision_hidden_states = ()
  518. language_hidden_states = ()
  519. vision_attentions = () if output_attentions or self.config.output_attentions else None
  520. language_attentions = () if output_attentions or self.config.output_attentions else None
  521. cross_encoder_attentions = () if output_attentions or self.config.output_attentions else None
  522. visual_feats = self.visn_fc(visual_feats, visual_pos)
  523. # Run language layers
  524. for layer_module in self.layer:
  525. l_outputs = layer_module(lang_feats, lang_attention_mask, output_attentions=output_attentions)
  526. lang_feats = l_outputs[0]
  527. language_hidden_states = language_hidden_states + (lang_feats,)
  528. if language_attentions is not None:
  529. language_attentions = language_attentions + (l_outputs[1],)
  530. # Run relational layers
  531. for layer_module in self.r_layers:
  532. v_outputs = layer_module(visual_feats, visual_attention_mask, output_attentions=output_attentions)
  533. visual_feats = v_outputs[0]
  534. vision_hidden_states = vision_hidden_states + (visual_feats,)
  535. if vision_attentions is not None:
  536. vision_attentions = vision_attentions + (v_outputs[1],)
  537. # Run cross-modality layers
  538. for layer_module in self.x_layers:
  539. x_outputs = layer_module(
  540. lang_feats,
  541. lang_attention_mask,
  542. visual_feats,
  543. visual_attention_mask,
  544. output_attentions=output_attentions,
  545. )
  546. lang_feats, visual_feats = x_outputs[:2]
  547. vision_hidden_states = vision_hidden_states + (visual_feats,)
  548. language_hidden_states = language_hidden_states + (lang_feats,)
  549. if cross_encoder_attentions is not None:
  550. cross_encoder_attentions = cross_encoder_attentions + (x_outputs[2],)
  551. visual_encoder_outputs = (
  552. vision_hidden_states,
  553. vision_attentions if output_attentions else None,
  554. )
  555. lang_encoder_outputs = (
  556. language_hidden_states,
  557. language_attentions if output_attentions else None,
  558. )
  559. return (
  560. visual_encoder_outputs,
  561. lang_encoder_outputs,
  562. cross_encoder_attentions if output_attentions else None,
  563. )
  564. class LxmertPooler(nn.Module):
  565. def __init__(self, config):
  566. super().__init__()
  567. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  568. self.activation = nn.Tanh()
  569. def forward(self, hidden_states):
  570. # We "pool" the model by simply taking the hidden state corresponding
  571. # to the first token.
  572. first_token_tensor = hidden_states[:, 0]
  573. pooled_output = self.dense(first_token_tensor)
  574. pooled_output = self.activation(pooled_output)
  575. return pooled_output
  576. class LxmertPredictionHeadTransform(nn.Module):
  577. def __init__(self, config):
  578. super().__init__()
  579. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  580. self.transform_act_fn = ACT2FN[config.hidden_act]
  581. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
  582. def forward(self, hidden_states):
  583. hidden_states = self.dense(hidden_states)
  584. hidden_states = self.transform_act_fn(hidden_states)
  585. hidden_states = self.LayerNorm(hidden_states)
  586. return hidden_states
  587. class LxmertLMPredictionHead(nn.Module):
  588. def __init__(self, config, lxmert_model_embedding_weights):
  589. super().__init__()
  590. self.transform = LxmertPredictionHeadTransform(config)
  591. # The output weights are the same as the input embeddings, but there is
  592. # an output-only bias for each token.
  593. self.decoder = nn.Linear(
  594. lxmert_model_embedding_weights.size(1),
  595. lxmert_model_embedding_weights.size(0),
  596. bias=False,
  597. )
  598. self.decoder.weight = lxmert_model_embedding_weights
  599. self.bias = nn.Parameter(torch.zeros(lxmert_model_embedding_weights.size(0)))
  600. def forward(self, hidden_states):
  601. hidden_states = self.transform(hidden_states)
  602. hidden_states = self.decoder(hidden_states) + self.bias
  603. return hidden_states
  604. class LxmertVisualAnswerHead(nn.Module):
  605. def __init__(self, config, num_labels):
  606. super().__init__()
  607. hid_dim = config.hidden_size
  608. self.logit_fc = nn.Sequential(
  609. nn.Linear(hid_dim, hid_dim * 2),
  610. GeLU(),
  611. nn.LayerNorm(hid_dim * 2, eps=1e-12),
  612. nn.Linear(hid_dim * 2, num_labels),
  613. )
  614. def forward(self, hidden_states):
  615. return self.logit_fc(hidden_states)
  616. class LxmertVisualObjHead(nn.Module):
  617. def __init__(self, config):
  618. super().__init__()
  619. self.transform = LxmertPredictionHeadTransform(config)
  620. # Decide the use of visual losses
  621. visual_losses = {}
  622. if config.visual_obj_loss:
  623. visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels}
  624. if config.visual_attr_loss:
  625. visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels}
  626. if config.visual_feat_loss:
  627. visual_losses["feat"] = {
  628. "shape": (-1, config.visual_feat_dim),
  629. "num": config.visual_feat_dim,
  630. }
  631. self.visual_losses = visual_losses
  632. # The output weights are the same as the input embeddings, but there is
  633. # an output-only bias for each token.
  634. self.decoder_dict = nn.ModuleDict(
  635. {key: nn.Linear(config.hidden_size, self.visual_losses[key]["num"]) for key in self.visual_losses}
  636. )
  637. def forward(self, hidden_states):
  638. hidden_states = self.transform(hidden_states)
  639. output = {}
  640. for key in self.visual_losses:
  641. output[key] = self.decoder_dict[key](hidden_states)
  642. return output
  643. class LxmertPreTrainingHeads(nn.Module):
  644. def __init__(self, config, lxmert_model_embedding_weights):
  645. super().__init__()
  646. self.predictions = LxmertLMPredictionHead(config, lxmert_model_embedding_weights)
  647. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  648. def forward(self, sequence_output, pooled_output):
  649. prediction_scores = self.predictions(sequence_output)
  650. seq_relationship_score = self.seq_relationship(pooled_output)
  651. return prediction_scores, seq_relationship_score
  652. @auto_docstring
  653. class LxmertPreTrainedModel(PreTrainedModel):
  654. config: LxmertConfig
  655. load_tf_weights = load_tf_weights_in_lxmert
  656. base_model_prefix = "lxmert"
  657. _supports_param_buffer_assignment = False
  658. def _init_weights(self, module):
  659. """Initialize the weights"""
  660. if isinstance(module, nn.Linear):
  661. # Slightly different from the TF version which uses truncated_normal for initialization
  662. # cf https://github.com/pytorch/pytorch/pull/5617
  663. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  664. if module.bias is not None:
  665. module.bias.data.zero_()
  666. elif isinstance(module, nn.Embedding):
  667. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  668. if module.padding_idx is not None:
  669. module.weight.data[module.padding_idx].zero_()
  670. elif isinstance(module, nn.LayerNorm):
  671. module.bias.data.zero_()
  672. module.weight.data.fill_(1.0)
  673. elif isinstance(module, LxmertLMPredictionHead):
  674. module.bias.data.zero_()
  675. @auto_docstring
  676. class LxmertModel(LxmertPreTrainedModel):
  677. def __init__(self, config):
  678. super().__init__(config)
  679. self.embeddings = LxmertEmbeddings(config)
  680. self.encoder = LxmertEncoder(config)
  681. self.pooler = LxmertPooler(config)
  682. # Initialize weights and apply final processing
  683. self.post_init()
  684. def get_input_embeddings(self):
  685. return self.embeddings.word_embeddings
  686. def set_input_embeddings(self, new_embeddings):
  687. self.embeddings.word_embeddings = new_embeddings
  688. @auto_docstring
  689. def forward(
  690. self,
  691. input_ids: Optional[torch.LongTensor] = None,
  692. visual_feats: Optional[torch.FloatTensor] = None,
  693. visual_pos: Optional[torch.FloatTensor] = None,
  694. attention_mask: Optional[torch.FloatTensor] = None,
  695. visual_attention_mask: Optional[torch.FloatTensor] = None,
  696. token_type_ids: Optional[torch.LongTensor] = None,
  697. inputs_embeds: Optional[torch.FloatTensor] = None,
  698. output_attentions: Optional[bool] = None,
  699. output_hidden_states: Optional[bool] = None,
  700. return_dict: Optional[bool] = None,
  701. ) -> Union[LxmertModelOutput, tuple[torch.FloatTensor]]:
  702. r"""
  703. visual_feats (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_feat_dim)`):
  704. This input represents visual features. They ROI pooled object features from bounding boxes using a
  705. faster-RCNN model)
  706. These are currently not provided by the transformers library.
  707. visual_pos (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_pos_dim)`):
  708. This input represents spatial features corresponding to their relative (via index) visual features. The
  709. pre-trained LXMERT model expects these spatial features to be normalized bounding boxes on a scale of 0 to
  710. 1.
  711. These are currently not provided by the transformers library.
  712. visual_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  713. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  714. - 1 for tokens that are **not masked**,
  715. - 0 for tokens that are **masked**.
  716. [What are attention masks?](../glossary#attention-mask)
  717. """
  718. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  719. output_hidden_states = (
  720. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  721. )
  722. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  723. if input_ids is not None and inputs_embeds is not None:
  724. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  725. elif input_ids is not None:
  726. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  727. input_shape = input_ids.size()
  728. elif inputs_embeds is not None:
  729. input_shape = inputs_embeds.size()[:-1]
  730. else:
  731. raise ValueError("You have to specify either input_ids or inputs_embeds")
  732. if visual_feats is None:
  733. raise ValueError("`visual_feats` cannot be `None`")
  734. if visual_pos is None:
  735. raise ValueError("`visual_pos` cannot be `None`")
  736. device = input_ids.device if input_ids is not None else inputs_embeds.device
  737. if attention_mask is None:
  738. attention_mask = torch.ones(input_shape, device=device)
  739. if token_type_ids is None:
  740. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  741. # We create a 3D attention mask from a 2D tensor mask.
  742. # Sizes are [batch_size, 1, 1, to_seq_length]
  743. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  744. # this attention mask is more simple than the triangular masking of causal attention
  745. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  746. extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  747. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  748. # masked positions, this operation will create a tensor which is 0.0 for
  749. # positions we want to attend and the dtype's smallest value for masked positions.
  750. # Since we are adding it to the raw scores before the softmax, this is
  751. # effectively the same as removing these entirely.
  752. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
  753. extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
  754. # Process the visual attention mask
  755. if visual_attention_mask is not None:
  756. extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
  757. extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=self.dtype)
  758. extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * torch.finfo(self.dtype).min
  759. else:
  760. extended_visual_attention_mask = None
  761. # Positional Word Embeddings
  762. embedding_output = self.embeddings(input_ids, token_type_ids, inputs_embeds)
  763. # Run Lxmert encoder
  764. encoder_outputs = self.encoder(
  765. embedding_output,
  766. extended_attention_mask,
  767. visual_feats=visual_feats,
  768. visual_pos=visual_pos,
  769. visual_attention_mask=extended_visual_attention_mask,
  770. output_attentions=output_attentions,
  771. )
  772. visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
  773. vision_hidden_states = visual_encoder_outputs[0]
  774. language_hidden_states = lang_encoder_outputs[0]
  775. all_attentions = ()
  776. if output_attentions:
  777. language_attentions = lang_encoder_outputs[1]
  778. vision_attentions = visual_encoder_outputs[1]
  779. cross_encoder_attentions = encoder_outputs[2]
  780. all_attentions = (
  781. language_attentions,
  782. vision_attentions,
  783. cross_encoder_attentions,
  784. )
  785. hidden_states = (language_hidden_states, vision_hidden_states) if output_hidden_states else ()
  786. visual_output = vision_hidden_states[-1]
  787. lang_output = language_hidden_states[-1]
  788. pooled_output = self.pooler(lang_output)
  789. if not return_dict:
  790. return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions
  791. return LxmertModelOutput(
  792. pooled_output=pooled_output,
  793. language_output=lang_output,
  794. vision_output=visual_output,
  795. language_hidden_states=language_hidden_states if output_hidden_states else None,
  796. vision_hidden_states=vision_hidden_states if output_hidden_states else None,
  797. language_attentions=language_attentions if output_attentions else None,
  798. vision_attentions=vision_attentions if output_attentions else None,
  799. cross_encoder_attentions=cross_encoder_attentions if output_attentions else None,
  800. )
  801. @auto_docstring
  802. class LxmertForPreTraining(LxmertPreTrainedModel):
  803. _tied_weights_keys = ["cls.predictions.decoder.weight"]
  804. def __init__(self, config):
  805. super().__init__(config)
  806. # Configuration
  807. self.config = config
  808. self.num_qa_labels = config.num_qa_labels
  809. self.visual_loss_normalizer = config.visual_loss_normalizer
  810. # Use of pretraining tasks
  811. self.task_mask_lm = config.task_mask_lm
  812. self.task_obj_predict = config.task_obj_predict
  813. self.task_matched = config.task_matched
  814. self.task_qa = config.task_qa
  815. # Lxmert backbone
  816. self.lxmert = LxmertModel(config)
  817. # Pre-training heads
  818. self.cls = LxmertPreTrainingHeads(config, self.lxmert.embeddings.word_embeddings.weight)
  819. if self.task_obj_predict:
  820. self.obj_predict_head = LxmertVisualObjHead(config)
  821. if self.task_qa:
  822. self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)
  823. # Weight initialization
  824. # Initialize weights and apply final processing
  825. self.post_init()
  826. # Loss functions
  827. self.loss_fcts = {
  828. "l2": SmoothL1Loss(reduction="none"),
  829. "visual_ce": CrossEntropyLoss(reduction="none"),
  830. "ce": CrossEntropyLoss(),
  831. }
  832. visual_losses = {}
  833. if config.visual_obj_loss:
  834. visual_losses["obj"] = {
  835. "shape": (-1,),
  836. "num": config.num_object_labels,
  837. "loss": "visual_ce",
  838. }
  839. if config.visual_attr_loss:
  840. visual_losses["attr"] = {
  841. "shape": (-1,),
  842. "num": config.num_attr_labels,
  843. "loss": "visual_ce",
  844. }
  845. if config.visual_feat_loss:
  846. visual_losses["feat"] = {
  847. "shape": (-1, config.visual_feat_dim),
  848. "num": config.visual_feat_dim,
  849. "loss": "l2",
  850. }
  851. self.visual_losses = visual_losses
  852. def _tie_weights(self):
  853. self.cls.predictions.decoder.weight = self.lxmert.embeddings.word_embeddings.weight
  854. def resize_token_embeddings(
  855. self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
  856. ) -> nn.Embedding:
  857. # Adding the following steps to resize bias to match the shape of resized embeddings
  858. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  859. self.cls.predictions.bias = self._resize_bias(self.cls.predictions.bias, new_num_tokens)
  860. return new_embeddings
  861. def _resize_bias(self, bias, new_num_tokens: int):
  862. old_num_tokens = bias.shape[0]
  863. if new_num_tokens <= old_num_tokens:
  864. new_bias = bias[:new_num_tokens]
  865. else:
  866. extra_bias = torch.zeros(new_num_tokens - old_num_tokens, device=bias.device)
  867. new_bias = torch.cat([bias, extra_bias])
  868. new_bias = nn.Parameter(new_bias)
  869. return new_bias
  870. def resize_num_qa_labels(self, num_labels):
  871. """
  872. Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
  873. will add newly initialized weights. Reducing the size will remove weights from the end
  874. Args:
  875. num_labels (`int`, *optional*):
  876. New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
  877. weights at the end. Reducing the size will remove weights from the end. If not provided or `None`, just
  878. returns a pointer to the qa labels ``torch.nn.Linear``` module of the model without doing anything.
  879. Return:
  880. `torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
  881. """
  882. cur_qa_logit_layer = self.get_qa_logit_layer()
  883. if num_labels is None or cur_qa_logit_layer is None:
  884. return
  885. new_qa_logit_layer = self._resize_qa_labels(num_labels)
  886. self.config.num_qa_labels = num_labels
  887. self.num_qa_labels = num_labels
  888. return new_qa_logit_layer
  889. def _resize_qa_labels(self, num_labels):
  890. cur_qa_logit_layer = self.get_qa_logit_layer()
  891. new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels)
  892. self._set_qa_logit_layer(new_qa_logit_layer)
  893. return self.get_qa_logit_layer()
  894. def get_qa_logit_layer(self) -> nn.Module:
  895. """
  896. Returns the linear layer that produces question answering logits.
  897. Returns:
  898. `nn.Module`: A torch module mapping the question answering prediction hidden states or `None` if LXMERT
  899. does not have a visual answering head.
  900. """
  901. if hasattr(self, "answer_head"):
  902. return self.answer_head.logit_fc[-1]
  903. def _set_qa_logit_layer(self, qa_logit_layer):
  904. self.answer_head.logit_fc[-1] = qa_logit_layer
  905. def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
  906. if num_labels is None:
  907. return cur_qa_logit_layer
  908. cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
  909. if cur_qa_labels == num_labels:
  910. return cur_qa_logit_layer
  911. # Build new linear output
  912. if getattr(cur_qa_logit_layer, "bias", None) is not None:
  913. new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
  914. else:
  915. new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)
  916. new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)
  917. # initialize all new labels
  918. self._init_weights(new_qa_logit_layer)
  919. # Copy labels from the previous weights
  920. num_labels_to_copy = min(cur_qa_labels, num_labels)
  921. new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
  922. if getattr(cur_qa_logit_layer, "bias", None) is not None:
  923. new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]
  924. return new_qa_logit_layer
  925. @auto_docstring
  926. def forward(
  927. self,
  928. input_ids: Optional[torch.LongTensor] = None,
  929. visual_feats: Optional[torch.FloatTensor] = None,
  930. visual_pos: Optional[torch.FloatTensor] = None,
  931. attention_mask: Optional[torch.FloatTensor] = None,
  932. visual_attention_mask: Optional[torch.FloatTensor] = None,
  933. token_type_ids: Optional[torch.LongTensor] = None,
  934. inputs_embeds: Optional[torch.FloatTensor] = None,
  935. labels: Optional[torch.LongTensor] = None,
  936. obj_labels: Optional[dict[str, tuple[torch.FloatTensor, torch.FloatTensor]]] = None,
  937. matched_label: Optional[torch.LongTensor] = None,
  938. ans: Optional[torch.Tensor] = None,
  939. output_attentions: Optional[bool] = None,
  940. output_hidden_states: Optional[bool] = None,
  941. return_dict: Optional[bool] = None,
  942. **kwargs,
  943. ) -> Union[LxmertForPreTrainingOutput, tuple[torch.FloatTensor]]:
  944. r"""
  945. visual_feats (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_feat_dim)`):
  946. This input represents visual features. They ROI pooled object features from bounding boxes using a
  947. faster-RCNN model)
  948. These are currently not provided by the transformers library.
  949. visual_pos (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_pos_dim)`):
  950. This input represents spatial features corresponding to their relative (via index) visual features. The
  951. pre-trained LXMERT model expects these spatial features to be normalized bounding boxes on a scale of 0 to
  952. 1.
  953. These are currently not provided by the transformers library.
  954. visual_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  955. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  956. - 1 for tokens that are **not masked**,
  957. - 0 for tokens that are **masked**.
  958. [What are attention masks?](../glossary#attention-mask)
  959. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  960. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  961. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  962. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  963. obj_labels (`dict[Str: tuple[Torch.FloatTensor, Torch.FloatTensor]]`, *optional*):
  964. each key is named after each one of the visual losses and each element of the tuple is of the shape
  965. `(batch_size, num_features)` and `(batch_size, num_features, visual_feature_dim)` for each the label id and
  966. the label score respectively
  967. matched_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  968. Labels for computing the whether or not the text input matches the image (classification) loss. Input
  969. should be a sequence pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
  970. - 0 indicates that the sentence does not match the image,
  971. - 1 indicates that the sentence does match the image.
  972. ans (`Torch.Tensor` of shape `(batch_size)`, *optional*):
  973. a one hot representation hof the correct answer *optional*
  974. """
  975. if "masked_lm_labels" in kwargs:
  976. warnings.warn(
  977. "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels`"
  978. " instead.",
  979. FutureWarning,
  980. )
  981. labels = kwargs.pop("masked_lm_labels")
  982. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  983. device = input_ids.device if input_ids is not None else inputs_embeds.device
  984. lxmert_output = self.lxmert(
  985. input_ids=input_ids,
  986. visual_feats=visual_feats,
  987. visual_pos=visual_pos,
  988. token_type_ids=token_type_ids,
  989. attention_mask=attention_mask,
  990. visual_attention_mask=visual_attention_mask,
  991. inputs_embeds=inputs_embeds,
  992. output_hidden_states=output_hidden_states,
  993. output_attentions=output_attentions,
  994. return_dict=return_dict,
  995. )
  996. lang_output, visual_output, pooled_output = (
  997. lxmert_output[0],
  998. lxmert_output[1],
  999. lxmert_output[2],
  1000. )
  1001. lang_prediction_scores, cross_relationship_score = self.cls(lang_output, pooled_output)
  1002. if self.task_qa:
  1003. answer_score = self.answer_head(pooled_output)
  1004. else:
  1005. answer_score = pooled_output[0][0]
  1006. total_loss = (
  1007. None
  1008. if (labels is None and matched_label is None and obj_labels is None and ans is None)
  1009. else torch.tensor(0.0, device=device)
  1010. )
  1011. if labels is not None and self.task_mask_lm:
  1012. masked_lm_loss = self.loss_fcts["ce"](
  1013. lang_prediction_scores.view(-1, self.config.vocab_size),
  1014. labels.view(-1),
  1015. )
  1016. total_loss += masked_lm_loss
  1017. if matched_label is not None and self.task_matched:
  1018. matched_loss = self.loss_fcts["ce"](cross_relationship_score.view(-1, 2), matched_label.view(-1))
  1019. total_loss += matched_loss
  1020. if obj_labels is not None and self.task_obj_predict:
  1021. total_visual_loss = torch.tensor(0.0, device=input_ids.device)
  1022. visual_prediction_scores_dict = self.obj_predict_head(visual_output)
  1023. for key, key_info in self.visual_losses.items():
  1024. label, mask_conf = obj_labels[key]
  1025. output_dim = key_info["num"]
  1026. loss_fct_name = key_info["loss"]
  1027. label_shape = key_info["shape"]
  1028. weight = self.visual_loss_normalizer
  1029. visual_loss_fct = self.loss_fcts[loss_fct_name]
  1030. visual_prediction_scores = visual_prediction_scores_dict[key]
  1031. visual_loss = visual_loss_fct(
  1032. visual_prediction_scores.view(-1, output_dim),
  1033. label.view(label_shape),
  1034. )
  1035. if visual_loss.dim() > 1: # Regression Losses
  1036. visual_loss = visual_loss.mean(1)
  1037. visual_loss = (visual_loss * mask_conf.view(-1)).mean() * weight
  1038. total_visual_loss += visual_loss
  1039. total_loss += total_visual_loss
  1040. if ans is not None and self.task_qa:
  1041. answer_loss = self.loss_fcts["ce"](answer_score.view(-1, self.num_qa_labels), ans.view(-1))
  1042. total_loss += answer_loss
  1043. if not return_dict:
  1044. output = (
  1045. lang_prediction_scores,
  1046. cross_relationship_score,
  1047. answer_score,
  1048. ) + lxmert_output[3:]
  1049. return ((total_loss,) + output) if total_loss is not None else output
  1050. return LxmertForPreTrainingOutput(
  1051. loss=total_loss,
  1052. prediction_logits=lang_prediction_scores,
  1053. cross_relationship_score=cross_relationship_score,
  1054. question_answering_score=answer_score,
  1055. language_hidden_states=lxmert_output.language_hidden_states,
  1056. vision_hidden_states=lxmert_output.vision_hidden_states,
  1057. language_attentions=lxmert_output.language_attentions,
  1058. vision_attentions=lxmert_output.vision_attentions,
  1059. cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
  1060. )
  1061. @auto_docstring(
  1062. custom_intro="""
  1063. Lxmert Model with a visual-answering head on top for downstream QA tasks
  1064. """
  1065. )
  1066. class LxmertForQuestionAnswering(LxmertPreTrainedModel):
  1067. def __init__(self, config):
  1068. super().__init__(config)
  1069. # Configuration
  1070. self.config = config
  1071. self.num_qa_labels = config.num_qa_labels
  1072. self.visual_loss_normalizer = config.visual_loss_normalizer
  1073. # Lxmert backbone
  1074. self.lxmert = LxmertModel(config)
  1075. self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)
  1076. # Weight initialization
  1077. # Initialize weights and apply final processing
  1078. self.post_init()
  1079. # Loss function
  1080. self.loss = CrossEntropyLoss()
  1081. def resize_num_qa_labels(self, num_labels):
  1082. """
  1083. Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
  1084. will add newly initialized weights. Reducing the size will remove weights from the end
  1085. Args:
  1086. num_labels (`int`, *optional*):
  1087. New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
  1088. weights at the end. Reducing the size will remove weights from the end. If not provided or `None`, just
  1089. returns a pointer to the qa labels ``torch.nn.Linear``` module of the model without doing anything.
  1090. Return:
  1091. `torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
  1092. """
  1093. cur_qa_logit_layer = self.get_qa_logit_layer()
  1094. if num_labels is None or cur_qa_logit_layer is None:
  1095. return
  1096. new_qa_logit_layer = self._resize_qa_labels(num_labels)
  1097. self.config.num_qa_labels = num_labels
  1098. self.num_qa_labels = num_labels
  1099. return new_qa_logit_layer
  1100. def _resize_qa_labels(self, num_labels):
  1101. cur_qa_logit_layer = self.get_qa_logit_layer()
  1102. new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels)
  1103. self._set_qa_logit_layer(new_qa_logit_layer)
  1104. return self.get_qa_logit_layer()
  1105. def get_qa_logit_layer(self) -> nn.Module:
  1106. """
  1107. Returns the linear layer that produces question answering logits
  1108. Returns:
  1109. `nn.Module`: A torch module mapping the question answering prediction hidden states. `None`: A NoneType
  1110. object if Lxmert does not have the visual answering head.
  1111. """
  1112. if hasattr(self, "answer_head"):
  1113. return self.answer_head.logit_fc[-1]
  1114. def _set_qa_logit_layer(self, qa_logit_layer):
  1115. self.answer_head.logit_fc[-1] = qa_logit_layer
  1116. def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
  1117. if num_labels is None:
  1118. return cur_qa_logit_layer
  1119. cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
  1120. if cur_qa_labels == num_labels:
  1121. return cur_qa_logit_layer
  1122. # Build new linear output
  1123. if getattr(cur_qa_logit_layer, "bias", None) is not None:
  1124. new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
  1125. else:
  1126. new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)
  1127. new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)
  1128. # initialize all new labels
  1129. self._init_weights(new_qa_logit_layer)
  1130. # Copy labels from the previous weights
  1131. num_labels_to_copy = min(cur_qa_labels, num_labels)
  1132. new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
  1133. if getattr(cur_qa_logit_layer, "bias", None) is not None:
  1134. new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]
  1135. return new_qa_logit_layer
  @auto_docstring
  def forward(
      self,
      input_ids: Optional[torch.LongTensor] = None,
      visual_feats: Optional[torch.FloatTensor] = None,
      visual_pos: Optional[torch.FloatTensor] = None,
      attention_mask: Optional[torch.FloatTensor] = None,
      visual_attention_mask: Optional[torch.FloatTensor] = None,
      token_type_ids: Optional[torch.LongTensor] = None,
      inputs_embeds: Optional[torch.FloatTensor] = None,
      labels: Optional[torch.Tensor] = None,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
  ) -> Union[LxmertForQuestionAnsweringOutput, tuple[torch.FloatTensor]]:
      r"""
      visual_feats (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_feat_dim)`):
          This input represents visual features. They are ROI-pooled object features from bounding boxes,
          extracted using a Faster R-CNN model.

          These are currently not provided by the transformers library.
      visual_pos (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_pos_dim)`):
          This input represents spatial features corresponding to their relative (via index) visual features. The
          pre-trained LXMERT model expects these spatial features to be normalized bounding boxes on a scale of 0
          to 1.

          These are currently not provided by the transformers library.
      visual_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
          Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

          - 1 for tokens that are **not masked**,
          - 0 for tokens that are **masked**.

          [What are attention masks?](../glossary#attention-mask)
      labels (`torch.Tensor` of shape `(batch_size)`, *optional*):
          Ground-truth answers used to compute the loss. The tensor is flattened with `view(-1)` and handed to
          the model's loss function together with the answer scores reshaped to `(-1, num_qa_labels)`.
      """
      # Fall back to the configuration default when the caller does not choose an output format.
      if return_dict is None:
          return_dict = self.config.use_return_dict

      encoder_outputs = self.lxmert(
          input_ids=input_ids,
          visual_feats=visual_feats,
          visual_pos=visual_pos,
          attention_mask=attention_mask,
          visual_attention_mask=visual_attention_mask,
          token_type_ids=token_type_ids,
          inputs_embeds=inputs_embeds,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
      )

      # The third entry of the base model output is used as the pooled representation
      # that feeds the answer classifier.
      pooled = encoder_outputs[2]
      qa_scores = self.answer_head(pooled)

      loss = None
      if labels is not None:
          # NOTE(review): labels are flattened to 1-D here, which suggests one class index per
          # example rather than one-hot vectors — confirm against the loss assigned to `self.loss`.
          flat_scores = qa_scores.view(-1, self.num_qa_labels)
          loss = self.loss(flat_scores, labels.view(-1))

      if not return_dict:
          # Tuple output: answer scores followed by everything after the pooled output,
          # with the loss prepended only when it was computed.
          outputs = (qa_scores,) + encoder_outputs[3:]
          if loss is None:
              return outputs
          return (loss,) + outputs

      return LxmertForQuestionAnsweringOutput(
          loss=loss,
          question_answering_score=qa_scores,
          language_hidden_states=encoder_outputs.language_hidden_states,
          vision_hidden_states=encoder_outputs.vision_hidden_states,
          language_attentions=encoder_outputs.language_attentions,
          vision_attentions=encoder_outputs.vision_attentions,
          cross_encoder_attentions=encoder_outputs.cross_encoder_attentions,
      )
  # Names exported by `from <module> import *`; this list declares the module's public API.
  __all__ = [
      "LxmertEncoder",
      "LxmertForPreTraining",
      "LxmertForQuestionAnswering",
      "LxmertModel",
      "LxmertPreTrainedModel",
      "LxmertVisualFeatureEncoder",
      "LxmertXLayer",
  ]