modeling_lilt.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092
  1. # coding=utf-8
  2. # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch LiLT model."""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. BaseModelOutputWithPooling,
  26. QuestionAnsweringModelOutput,
  27. SequenceClassifierOutput,
  28. TokenClassifierOutput,
  29. )
  30. from ...modeling_utils import PreTrainedModel
  31. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  32. from ...utils import auto_docstring, logging
  33. from .configuration_lilt import LiltConfig
  34. logger = logging.get_logger(__name__)
  35. class LiltTextEmbeddings(nn.Module):
  36. def __init__(self, config):
  37. super().__init__()
  38. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  39. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  40. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  41. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  42. # any TensorFlow checkpoint file
  43. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  44. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  45. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  46. self.register_buffer(
  47. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  48. )
  49. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  50. # End copy
  51. self.padding_idx = config.pad_token_id
  52. self.position_embeddings = nn.Embedding(
  53. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  54. )
  55. def forward(
  56. self,
  57. input_ids=None,
  58. token_type_ids=None,
  59. position_ids=None,
  60. inputs_embeds=None,
  61. ):
  62. if position_ids is None:
  63. if input_ids is not None:
  64. # Create the position ids from the input token ids. Any padded tokens remain padded.
  65. position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx).to(
  66. input_ids.device
  67. )
  68. else:
  69. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  70. if input_ids is not None:
  71. input_shape = input_ids.size()
  72. else:
  73. input_shape = inputs_embeds.size()[:-1]
  74. if token_type_ids is None:
  75. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  76. if inputs_embeds is None:
  77. inputs_embeds = self.word_embeddings(input_ids)
  78. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  79. embeddings = inputs_embeds + token_type_embeddings
  80. if self.position_embedding_type == "absolute":
  81. position_embeddings = self.position_embeddings(position_ids)
  82. embeddings += position_embeddings
  83. embeddings = self.LayerNorm(embeddings)
  84. embeddings = self.dropout(embeddings)
  85. return embeddings, position_ids
  86. def create_position_ids_from_input_ids(self, input_ids, padding_idx):
  87. """
  88. Args:
  89. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
  90. symbols are ignored. This is modified from fairseq's `utils.make_positions`.
  91. x: torch.Tensor x:
  92. Returns: torch.Tensor
  93. """
  94. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  95. mask = input_ids.ne(padding_idx).int()
  96. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
  97. return incremental_indices.long() + padding_idx
  98. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  99. """
  100. Args:
  101. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.:
  102. inputs_embeds: torch.Tensor
  103. Returns: torch.Tensor
  104. """
  105. input_shape = inputs_embeds.size()[:-1]
  106. sequence_length = input_shape[1]
  107. position_ids = torch.arange(
  108. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  109. )
  110. return position_ids.unsqueeze(0).expand(input_shape)
  111. class LiltLayoutEmbeddings(nn.Module):
  112. def __init__(self, config):
  113. super().__init__()
  114. # we divide the hidden_size by 6 here as there are 6 different layout embeddings,
  115. # namely left_position, upper_position, right_position, lower_position, height, width
  116. self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
  117. self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
  118. self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
  119. self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
  120. self.padding_idx = config.pad_token_id
  121. self.box_position_embeddings = nn.Embedding(
  122. config.max_position_embeddings,
  123. config.hidden_size // config.channel_shrink_ratio,
  124. padding_idx=self.padding_idx,
  125. )
  126. self.box_linear_embeddings = nn.Linear(
  127. in_features=config.hidden_size, out_features=config.hidden_size // config.channel_shrink_ratio
  128. )
  129. self.LayerNorm = nn.LayerNorm(config.hidden_size // config.channel_shrink_ratio, eps=config.layer_norm_eps)
  130. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  131. def forward(self, bbox=None, position_ids=None):
  132. try:
  133. left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
  134. upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
  135. right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
  136. lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
  137. except IndexError as e:
  138. raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e
  139. h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
  140. w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
  141. spatial_position_embeddings = torch.cat(
  142. [
  143. left_position_embeddings,
  144. upper_position_embeddings,
  145. right_position_embeddings,
  146. lower_position_embeddings,
  147. h_position_embeddings,
  148. w_position_embeddings,
  149. ],
  150. dim=-1,
  151. )
  152. spatial_position_embeddings = self.box_linear_embeddings(spatial_position_embeddings)
  153. box_position_embeddings = self.box_position_embeddings(position_ids)
  154. spatial_position_embeddings = spatial_position_embeddings + box_position_embeddings
  155. spatial_position_embeddings = self.LayerNorm(spatial_position_embeddings)
  156. spatial_position_embeddings = self.dropout(spatial_position_embeddings)
  157. return spatial_position_embeddings
  158. class LiltSelfAttention(nn.Module):
  159. def __init__(self, config, position_embedding_type=None, layer_idx=None):
  160. super().__init__()
  161. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  162. raise ValueError(
  163. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  164. f"heads ({config.num_attention_heads})"
  165. )
  166. self.num_attention_heads = config.num_attention_heads
  167. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  168. self.all_head_size = self.num_attention_heads * self.attention_head_size
  169. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  170. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  171. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  172. self.layout_query = nn.Linear(
  173. config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio
  174. )
  175. self.layout_key = nn.Linear(
  176. config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio
  177. )
  178. self.layout_value = nn.Linear(
  179. config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio
  180. )
  181. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  182. self.position_embedding_type = position_embedding_type or getattr(
  183. config, "position_embedding_type", "absolute"
  184. )
  185. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  186. self.max_position_embeddings = config.max_position_embeddings
  187. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  188. self.channel_shrink_ratio = config.channel_shrink_ratio
  189. self.layer_idx = layer_idx
  190. def transpose_for_scores(self, x, r=1):
  191. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size // r)
  192. x = x.view(*new_x_shape)
  193. return x.permute(0, 2, 1, 3)
  194. def forward(
  195. self,
  196. hidden_states,
  197. layout_inputs,
  198. attention_mask=None,
  199. head_mask=None,
  200. output_attentions=False,
  201. ):
  202. layout_value_layer = self.transpose_for_scores(self.layout_value(layout_inputs), r=self.channel_shrink_ratio)
  203. layout_key_layer = self.transpose_for_scores(self.layout_key(layout_inputs), r=self.channel_shrink_ratio)
  204. layout_query_layer = self.transpose_for_scores(self.layout_query(layout_inputs), r=self.channel_shrink_ratio)
  205. mixed_query_layer = self.query(hidden_states)
  206. key_layer = self.transpose_for_scores(self.key(hidden_states))
  207. value_layer = self.transpose_for_scores(self.value(hidden_states))
  208. query_layer = self.transpose_for_scores(mixed_query_layer)
  209. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  210. layout_attention_scores = torch.matmul(layout_query_layer, layout_key_layer.transpose(-1, -2))
  211. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  212. seq_length = hidden_states.size()[1]
  213. position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
  214. position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
  215. distance = position_ids_l - position_ids_r
  216. positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
  217. positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
  218. if self.position_embedding_type == "relative_key":
  219. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  220. attention_scores = attention_scores + relative_position_scores
  221. elif self.position_embedding_type == "relative_key_query":
  222. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  223. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  224. attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  225. tmp_attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  226. tmp_layout_attention_scores = layout_attention_scores / math.sqrt(
  227. self.attention_head_size // self.channel_shrink_ratio
  228. )
  229. attention_scores = tmp_attention_scores + tmp_layout_attention_scores
  230. layout_attention_scores = tmp_layout_attention_scores + tmp_attention_scores
  231. if attention_mask is not None:
  232. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  233. layout_attention_scores = layout_attention_scores + attention_mask
  234. # Normalize the attention scores to probabilities.
  235. layout_attention_probs = nn.Softmax(dim=-1)(layout_attention_scores)
  236. # This is actually dropping out entire tokens to attend to, which might
  237. # seem a bit unusual, but is taken from the original Transformer paper.
  238. layout_attention_probs = self.dropout(layout_attention_probs)
  239. # Mask heads if we want to
  240. if head_mask is not None:
  241. layout_attention_probs = layout_attention_probs * head_mask
  242. layout_context_layer = torch.matmul(layout_attention_probs, layout_value_layer)
  243. layout_context_layer = layout_context_layer.permute(0, 2, 1, 3).contiguous()
  244. new_context_layer_shape = layout_context_layer.size()[:-2] + (self.all_head_size // self.channel_shrink_ratio,)
  245. layout_context_layer = layout_context_layer.view(*new_context_layer_shape)
  246. if attention_mask is not None:
  247. # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
  248. attention_scores = attention_scores + attention_mask
  249. # Normalize the attention scores to probabilities.
  250. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  251. # This is actually dropping out entire tokens to attend to, which might
  252. # seem a bit unusual, but is taken from the original Transformer paper.
  253. attention_probs = self.dropout(attention_probs)
  254. # Mask heads if we want to
  255. if head_mask is not None:
  256. attention_probs = attention_probs * head_mask
  257. context_layer = torch.matmul(attention_probs, value_layer)
  258. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  259. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  260. context_layer = context_layer.view(*new_context_layer_shape)
  261. outputs = (
  262. ((context_layer, layout_context_layer), attention_probs)
  263. if output_attentions
  264. else ((context_layer, layout_context_layer),)
  265. )
  266. return outputs
  267. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
  268. class LiltSelfOutput(nn.Module):
  269. def __init__(self, config):
  270. super().__init__()
  271. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  272. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  273. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  274. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  275. hidden_states = self.dense(hidden_states)
  276. hidden_states = self.dropout(hidden_states)
  277. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  278. return hidden_states
  279. class LiltAttention(nn.Module):
  280. def __init__(self, config, position_embedding_type=None, layer_idx=None):
  281. super().__init__()
  282. self.self = LiltSelfAttention(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx)
  283. self.output = LiltSelfOutput(config)
  284. self.pruned_heads = set()
  285. ori_hidden_size = config.hidden_size
  286. config.hidden_size = config.hidden_size // config.channel_shrink_ratio
  287. self.layout_output = LiltSelfOutput(config)
  288. config.hidden_size = ori_hidden_size
  289. # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
  290. def prune_heads(self, heads):
  291. if len(heads) == 0:
  292. return
  293. heads, index = find_pruneable_heads_and_indices(
  294. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  295. )
  296. # Prune linear layers
  297. self.self.query = prune_linear_layer(self.self.query, index)
  298. self.self.key = prune_linear_layer(self.self.key, index)
  299. self.self.value = prune_linear_layer(self.self.value, index)
  300. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  301. # Update hyper params and store pruned heads
  302. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  303. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  304. self.pruned_heads = self.pruned_heads.union(heads)
  305. def forward(
  306. self,
  307. hidden_states: torch.Tensor,
  308. layout_inputs: torch.Tensor,
  309. attention_mask: Optional[torch.FloatTensor] = None,
  310. head_mask: Optional[torch.FloatTensor] = None,
  311. output_attentions: Optional[bool] = False,
  312. ) -> tuple[torch.Tensor]:
  313. self_outputs = self.self(
  314. hidden_states,
  315. layout_inputs,
  316. attention_mask,
  317. head_mask,
  318. output_attentions,
  319. )
  320. attention_output = self.output(self_outputs[0][0], hidden_states)
  321. layout_attention_output = self.layout_output(self_outputs[0][1], layout_inputs)
  322. outputs = ((attention_output, layout_attention_output),) + self_outputs[1:] # add attentions if we output them
  323. return outputs
  324. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  325. class LiltIntermediate(nn.Module):
  326. def __init__(self, config):
  327. super().__init__()
  328. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  329. if isinstance(config.hidden_act, str):
  330. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  331. else:
  332. self.intermediate_act_fn = config.hidden_act
  333. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  334. hidden_states = self.dense(hidden_states)
  335. hidden_states = self.intermediate_act_fn(hidden_states)
  336. return hidden_states
  337. # Copied from transformers.models.bert.modeling_bert.BertOutput
  338. class LiltOutput(nn.Module):
  339. def __init__(self, config):
  340. super().__init__()
  341. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  342. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  343. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  344. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  345. hidden_states = self.dense(hidden_states)
  346. hidden_states = self.dropout(hidden_states)
  347. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  348. return hidden_states
  349. class LiltLayer(GradientCheckpointingLayer):
  350. def __init__(self, config, layer_idx=None):
  351. super().__init__()
  352. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  353. self.seq_len_dim = 1
  354. self.attention = LiltAttention(config, layer_idx=layer_idx)
  355. self.intermediate = LiltIntermediate(config)
  356. self.output = LiltOutput(config)
  357. ori_hidden_size = config.hidden_size
  358. ori_intermediate_size = config.intermediate_size
  359. config.hidden_size = config.hidden_size // config.channel_shrink_ratio
  360. config.intermediate_size = config.intermediate_size // config.channel_shrink_ratio
  361. self.layout_intermediate = LiltIntermediate(config)
  362. self.layout_output = LiltOutput(config)
  363. config.hidden_size = ori_hidden_size
  364. config.intermediate_size = ori_intermediate_size
  365. def forward(
  366. self,
  367. hidden_states: torch.Tensor,
  368. layout_inputs: torch.Tensor,
  369. attention_mask: Optional[torch.FloatTensor] = None,
  370. head_mask: Optional[torch.FloatTensor] = None,
  371. output_attentions: Optional[bool] = False,
  372. ) -> tuple[torch.Tensor]:
  373. self_attention_outputs = self.attention(
  374. hidden_states,
  375. layout_inputs,
  376. attention_mask,
  377. head_mask,
  378. output_attentions=output_attentions,
  379. )
  380. attention_output = self_attention_outputs[0][0]
  381. layout_attention_output = self_attention_outputs[0][1]
  382. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  383. layer_output = apply_chunking_to_forward(
  384. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  385. )
  386. layout_layer_output = apply_chunking_to_forward(
  387. self.layout_feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, layout_attention_output
  388. )
  389. outputs = ((layer_output, layout_layer_output),) + outputs
  390. return outputs
  391. # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk
  392. def feed_forward_chunk(self, attention_output):
  393. intermediate_output = self.intermediate(attention_output)
  394. layer_output = self.output(intermediate_output, attention_output)
  395. return layer_output
  396. def layout_feed_forward_chunk(self, attention_output):
  397. intermediate_output = self.layout_intermediate(attention_output)
  398. layer_output = self.layout_output(intermediate_output, attention_output)
  399. return layer_output
  400. class LiltEncoder(nn.Module):
  401. # Copied from transformers.models.bert.modeling_bert.BertEncoder.__init__ with Bert->Lilt
  402. def __init__(self, config, layer_idx=None):
  403. super().__init__()
  404. self.config = config
  405. self.layer = nn.ModuleList([LiltLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  406. self.gradient_checkpointing = False
  407. def forward(
  408. self,
  409. hidden_states: torch.Tensor,
  410. layout_inputs: torch.Tensor,
  411. attention_mask: Optional[torch.FloatTensor] = None,
  412. head_mask: Optional[torch.FloatTensor] = None,
  413. output_attentions: Optional[bool] = False,
  414. output_hidden_states: Optional[bool] = False,
  415. return_dict: Optional[bool] = True,
  416. ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
  417. all_hidden_states = () if output_hidden_states else None
  418. all_self_attentions = () if output_attentions else None
  419. for i, layer_module in enumerate(self.layer):
  420. if output_hidden_states:
  421. all_hidden_states = all_hidden_states + (hidden_states,)
  422. layer_head_mask = head_mask[i] if head_mask is not None else None
  423. layer_outputs = layer_module(
  424. hidden_states,
  425. layout_inputs,
  426. attention_mask,
  427. layer_head_mask,
  428. output_attentions,
  429. )
  430. hidden_states = layer_outputs[0][0]
  431. layout_inputs = layer_outputs[0][1]
  432. if output_attentions:
  433. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  434. if output_hidden_states:
  435. all_hidden_states = all_hidden_states + (hidden_states,)
  436. if not return_dict:
  437. return tuple(
  438. v
  439. for v in [
  440. hidden_states,
  441. all_hidden_states,
  442. all_self_attentions,
  443. ]
  444. if v is not None
  445. )
  446. return BaseModelOutput(
  447. last_hidden_state=hidden_states,
  448. hidden_states=all_hidden_states,
  449. attentions=all_self_attentions,
  450. )
  451. # Copied from transformers.models.bert.modeling_bert.BertPooler
  452. class LiltPooler(nn.Module):
  453. def __init__(self, config):
  454. super().__init__()
  455. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  456. self.activation = nn.Tanh()
  457. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  458. # We "pool" the model by simply taking the hidden state corresponding
  459. # to the first token.
  460. first_token_tensor = hidden_states[:, 0]
  461. pooled_output = self.dense(first_token_tensor)
  462. pooled_output = self.activation(pooled_output)
  463. return pooled_output
  464. @auto_docstring
  465. class LiltPreTrainedModel(PreTrainedModel):
  466. config: LiltConfig
  467. base_model_prefix = "lilt"
  468. supports_gradient_checkpointing = True
  469. _no_split_modules = []
  470. def _init_weights(self, module):
  471. """Initialize the weights"""
  472. if isinstance(module, nn.Linear):
  473. # Slightly different from the TF version which uses truncated_normal for initialization
  474. # cf https://github.com/pytorch/pytorch/pull/5617
  475. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  476. if module.bias is not None:
  477. module.bias.data.zero_()
  478. elif isinstance(module, nn.Embedding):
  479. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  480. if module.padding_idx is not None:
  481. module.weight.data[module.padding_idx].zero_()
  482. elif isinstance(module, nn.LayerNorm):
  483. module.bias.data.zero_()
  484. module.weight.data.fill_(1.0)
  485. @auto_docstring
  486. class LiltModel(LiltPreTrainedModel):
  487. def __init__(self, config, add_pooling_layer=True):
  488. r"""
  489. add_pooling_layer (bool, *optional*, defaults to `True`):
  490. Whether to add a pooling layer
  491. """
  492. super().__init__(config)
  493. self.config = config
  494. self.embeddings = LiltTextEmbeddings(config)
  495. self.layout_embeddings = LiltLayoutEmbeddings(config)
  496. self.encoder = LiltEncoder(config)
  497. self.pooler = LiltPooler(config) if add_pooling_layer else None
  498. # Initialize weights and apply final processing
  499. self.post_init()
  500. def get_input_embeddings(self):
  501. return self.embeddings.word_embeddings
  502. def set_input_embeddings(self, value):
  503. self.embeddings.word_embeddings = value
  504. def _prune_heads(self, heads_to_prune):
  505. """
  506. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  507. class PreTrainedModel
  508. """
  509. for layer, heads in heads_to_prune.items():
  510. self.encoder.layer[layer].attention.prune_heads(heads)
  511. @auto_docstring
  512. def forward(
  513. self,
  514. input_ids: Optional[torch.Tensor] = None,
  515. bbox: Optional[torch.Tensor] = None,
  516. attention_mask: Optional[torch.Tensor] = None,
  517. token_type_ids: Optional[torch.Tensor] = None,
  518. position_ids: Optional[torch.Tensor] = None,
  519. head_mask: Optional[torch.Tensor] = None,
  520. inputs_embeds: Optional[torch.Tensor] = None,
  521. output_attentions: Optional[bool] = None,
  522. output_hidden_states: Optional[bool] = None,
  523. return_dict: Optional[bool] = None,
  524. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
  525. r"""
  526. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  527. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  528. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  529. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  530. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  531. Examples:
  532. ```python
  533. >>> from transformers import AutoTokenizer, AutoModel
  534. >>> from datasets import load_dataset
  535. >>> tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  536. >>> model = AutoModel.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  537. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  538. >>> example = dataset[0]
  539. >>> words = example["tokens"]
  540. >>> boxes = example["bboxes"]
  541. >>> encoding = tokenizer(words, boxes=boxes, return_tensors="pt")
  542. >>> outputs = model(**encoding)
  543. >>> last_hidden_states = outputs.last_hidden_state
  544. ```"""
  545. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  546. output_hidden_states = (
  547. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  548. )
  549. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  550. if input_ids is not None and inputs_embeds is not None:
  551. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  552. elif input_ids is not None:
  553. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  554. input_shape = input_ids.size()
  555. elif inputs_embeds is not None:
  556. input_shape = inputs_embeds.size()[:-1]
  557. else:
  558. raise ValueError("You have to specify either input_ids or inputs_embeds")
  559. batch_size, seq_length = input_shape
  560. device = input_ids.device if input_ids is not None else inputs_embeds.device
  561. if bbox is None:
  562. bbox = torch.zeros(input_shape + (4,), dtype=torch.long, device=device)
  563. if attention_mask is None:
  564. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  565. if token_type_ids is None:
  566. if hasattr(self.embeddings, "token_type_ids"):
  567. buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
  568. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
  569. token_type_ids = buffered_token_type_ids_expanded
  570. else:
  571. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  572. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  573. # ourselves in which case we just need to make it broadcastable to all heads.
  574. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  575. # Prepare head mask if needed
  576. # 1.0 in head_mask indicate we keep the head
  577. # attention_probs has shape bsz x n_heads x N x N
  578. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  579. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  580. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  581. embedding_output, position_ids = self.embeddings(
  582. input_ids=input_ids,
  583. position_ids=position_ids,
  584. token_type_ids=token_type_ids,
  585. inputs_embeds=inputs_embeds,
  586. )
  587. layout_embedding_output = self.layout_embeddings(bbox=bbox, position_ids=position_ids)
  588. encoder_outputs = self.encoder(
  589. embedding_output,
  590. layout_embedding_output,
  591. attention_mask=extended_attention_mask,
  592. head_mask=head_mask,
  593. output_attentions=output_attentions,
  594. output_hidden_states=output_hidden_states,
  595. return_dict=return_dict,
  596. )
  597. sequence_output = encoder_outputs[0]
  598. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  599. if not return_dict:
  600. return (sequence_output, pooled_output) + encoder_outputs[1:]
  601. return BaseModelOutputWithPooling(
  602. last_hidden_state=sequence_output,
  603. pooler_output=pooled_output,
  604. hidden_states=encoder_outputs.hidden_states,
  605. attentions=encoder_outputs.attentions,
  606. )
  607. @auto_docstring(
  608. custom_intro="""
  609. LiLT Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  610. output) e.g. for GLUE tasks.
  611. """
  612. )
  613. class LiltForSequenceClassification(LiltPreTrainedModel):
  614. # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification.__init__ with Roberta->Lilt, roberta->lilt
  615. def __init__(self, config):
  616. super().__init__(config)
  617. self.num_labels = config.num_labels
  618. self.config = config
  619. self.lilt = LiltModel(config, add_pooling_layer=False)
  620. self.classifier = LiltClassificationHead(config)
  621. # Initialize weights and apply final processing
  622. self.post_init()
  623. @auto_docstring
  624. def forward(
  625. self,
  626. input_ids: Optional[torch.LongTensor] = None,
  627. bbox: Optional[torch.Tensor] = None,
  628. attention_mask: Optional[torch.FloatTensor] = None,
  629. token_type_ids: Optional[torch.LongTensor] = None,
  630. position_ids: Optional[torch.LongTensor] = None,
  631. head_mask: Optional[torch.FloatTensor] = None,
  632. inputs_embeds: Optional[torch.FloatTensor] = None,
  633. labels: Optional[torch.LongTensor] = None,
  634. output_attentions: Optional[bool] = None,
  635. output_hidden_states: Optional[bool] = None,
  636. return_dict: Optional[bool] = None,
  637. ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
  638. r"""
  639. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  640. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  641. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  642. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  643. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  644. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  645. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  646. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  647. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  648. Examples:
  649. ```python
  650. >>> from transformers import AutoTokenizer, AutoModelForSequenceClassification
  651. >>> from datasets import load_dataset
  652. >>> tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  653. >>> model = AutoModelForSequenceClassification.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  654. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  655. >>> example = dataset[0]
  656. >>> words = example["tokens"]
  657. >>> boxes = example["bboxes"]
  658. >>> encoding = tokenizer(words, boxes=boxes, return_tensors="pt")
  659. >>> outputs = model(**encoding)
  660. >>> predicted_class_idx = outputs.logits.argmax(-1).item()
  661. >>> predicted_class = model.config.id2label[predicted_class_idx]
  662. ```"""
  663. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  664. outputs = self.lilt(
  665. input_ids,
  666. bbox=bbox,
  667. attention_mask=attention_mask,
  668. token_type_ids=token_type_ids,
  669. position_ids=position_ids,
  670. head_mask=head_mask,
  671. inputs_embeds=inputs_embeds,
  672. output_attentions=output_attentions,
  673. output_hidden_states=output_hidden_states,
  674. return_dict=return_dict,
  675. )
  676. sequence_output = outputs[0]
  677. logits = self.classifier(sequence_output)
  678. loss = None
  679. if labels is not None:
  680. # move labels to correct device to enable model parallelism
  681. labels = labels.to(logits.device)
  682. if self.config.problem_type is None:
  683. if self.num_labels == 1:
  684. self.config.problem_type = "regression"
  685. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  686. self.config.problem_type = "single_label_classification"
  687. else:
  688. self.config.problem_type = "multi_label_classification"
  689. if self.config.problem_type == "regression":
  690. loss_fct = MSELoss()
  691. if self.num_labels == 1:
  692. loss = loss_fct(logits.squeeze(), labels.squeeze())
  693. else:
  694. loss = loss_fct(logits, labels)
  695. elif self.config.problem_type == "single_label_classification":
  696. loss_fct = CrossEntropyLoss()
  697. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  698. elif self.config.problem_type == "multi_label_classification":
  699. loss_fct = BCEWithLogitsLoss()
  700. loss = loss_fct(logits, labels)
  701. if not return_dict:
  702. output = (logits,) + outputs[2:]
  703. return ((loss,) + output) if loss is not None else output
  704. return SequenceClassifierOutput(
  705. loss=loss,
  706. logits=logits,
  707. hidden_states=outputs.hidden_states,
  708. attentions=outputs.attentions,
  709. )
  710. @auto_docstring
  711. class LiltForTokenClassification(LiltPreTrainedModel):
  712. # Copied from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification.__init__ with Roberta->Lilt, roberta->lilt
  713. def __init__(self, config):
  714. super().__init__(config)
  715. self.num_labels = config.num_labels
  716. self.lilt = LiltModel(config, add_pooling_layer=False)
  717. classifier_dropout = (
  718. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  719. )
  720. self.dropout = nn.Dropout(classifier_dropout)
  721. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  722. # Initialize weights and apply final processing
  723. self.post_init()
  724. @auto_docstring
  725. def forward(
  726. self,
  727. input_ids: Optional[torch.LongTensor] = None,
  728. bbox: Optional[torch.LongTensor] = None,
  729. attention_mask: Optional[torch.FloatTensor] = None,
  730. token_type_ids: Optional[torch.LongTensor] = None,
  731. position_ids: Optional[torch.LongTensor] = None,
  732. head_mask: Optional[torch.FloatTensor] = None,
  733. inputs_embeds: Optional[torch.FloatTensor] = None,
  734. labels: Optional[torch.LongTensor] = None,
  735. output_attentions: Optional[bool] = None,
  736. output_hidden_states: Optional[bool] = None,
  737. return_dict: Optional[bool] = None,
  738. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  739. r"""
  740. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  741. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  742. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  743. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  744. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  745. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  746. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  747. Examples:
  748. ```python
  749. >>> from transformers import AutoTokenizer, AutoModelForTokenClassification
  750. >>> from datasets import load_dataset
  751. >>> tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  752. >>> model = AutoModelForTokenClassification.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  753. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  754. >>> example = dataset[0]
  755. >>> words = example["tokens"]
  756. >>> boxes = example["bboxes"]
  757. >>> encoding = tokenizer(words, boxes=boxes, return_tensors="pt")
  758. >>> outputs = model(**encoding)
  759. >>> predicted_class_indices = outputs.logits.argmax(-1)
  760. ```"""
  761. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  762. outputs = self.lilt(
  763. input_ids,
  764. bbox=bbox,
  765. attention_mask=attention_mask,
  766. token_type_ids=token_type_ids,
  767. position_ids=position_ids,
  768. head_mask=head_mask,
  769. inputs_embeds=inputs_embeds,
  770. output_attentions=output_attentions,
  771. output_hidden_states=output_hidden_states,
  772. return_dict=return_dict,
  773. )
  774. sequence_output = outputs[0]
  775. sequence_output = self.dropout(sequence_output)
  776. logits = self.classifier(sequence_output)
  777. loss = None
  778. if labels is not None:
  779. # move labels to correct device to enable model parallelism
  780. labels = labels.to(logits.device)
  781. loss_fct = CrossEntropyLoss()
  782. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  783. if not return_dict:
  784. output = (logits,) + outputs[2:]
  785. return ((loss,) + output) if loss is not None else output
  786. return TokenClassifierOutput(
  787. loss=loss,
  788. logits=logits,
  789. hidden_states=outputs.hidden_states,
  790. attentions=outputs.attentions,
  791. )
  792. # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->Lilt
  793. class LiltClassificationHead(nn.Module):
  794. """Head for sentence-level classification tasks."""
  795. def __init__(self, config):
  796. super().__init__()
  797. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  798. classifier_dropout = (
  799. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  800. )
  801. self.dropout = nn.Dropout(classifier_dropout)
  802. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  803. def forward(self, features, **kwargs):
  804. x = features[:, 0, :] # take <s> token (equiv. to [CLS])
  805. x = self.dropout(x)
  806. x = self.dense(x)
  807. x = torch.tanh(x)
  808. x = self.dropout(x)
  809. x = self.out_proj(x)
  810. return x
  811. @auto_docstring
  812. class LiltForQuestionAnswering(LiltPreTrainedModel):
  813. # Copied from transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering.__init__ with Roberta->Lilt, roberta->lilt
  814. def __init__(self, config):
  815. super().__init__(config)
  816. self.num_labels = config.num_labels
  817. self.lilt = LiltModel(config, add_pooling_layer=False)
  818. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  819. # Initialize weights and apply final processing
  820. self.post_init()
  821. @auto_docstring
  822. def forward(
  823. self,
  824. input_ids: Optional[torch.LongTensor] = None,
  825. bbox: Optional[torch.LongTensor] = None,
  826. attention_mask: Optional[torch.FloatTensor] = None,
  827. token_type_ids: Optional[torch.LongTensor] = None,
  828. position_ids: Optional[torch.LongTensor] = None,
  829. head_mask: Optional[torch.FloatTensor] = None,
  830. inputs_embeds: Optional[torch.FloatTensor] = None,
  831. start_positions: Optional[torch.LongTensor] = None,
  832. end_positions: Optional[torch.LongTensor] = None,
  833. output_attentions: Optional[bool] = None,
  834. output_hidden_states: Optional[bool] = None,
  835. return_dict: Optional[bool] = None,
  836. ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
  837. r"""
  838. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  839. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  840. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  841. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  842. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  843. Examples:
  844. ```python
  845. >>> from transformers import AutoTokenizer, AutoModelForQuestionAnswering
  846. >>> from datasets import load_dataset
  847. >>> tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  848. >>> model = AutoModelForQuestionAnswering.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  849. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  850. >>> example = dataset[0]
  851. >>> words = example["tokens"]
  852. >>> boxes = example["bboxes"]
  853. >>> encoding = tokenizer(words, boxes=boxes, return_tensors="pt")
  854. >>> outputs = model(**encoding)
  855. >>> answer_start_index = outputs.start_logits.argmax()
  856. >>> answer_end_index = outputs.end_logits.argmax()
  857. >>> predict_answer_tokens = encoding.input_ids[0, answer_start_index : answer_end_index + 1]
  858. >>> predicted_answer = tokenizer.decode(predict_answer_tokens)
  859. ```"""
  860. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  861. outputs = self.lilt(
  862. input_ids,
  863. bbox=bbox,
  864. attention_mask=attention_mask,
  865. token_type_ids=token_type_ids,
  866. position_ids=position_ids,
  867. head_mask=head_mask,
  868. inputs_embeds=inputs_embeds,
  869. output_attentions=output_attentions,
  870. output_hidden_states=output_hidden_states,
  871. return_dict=return_dict,
  872. )
  873. sequence_output = outputs[0]
  874. logits = self.qa_outputs(sequence_output)
  875. start_logits, end_logits = logits.split(1, dim=-1)
  876. start_logits = start_logits.squeeze(-1).contiguous()
  877. end_logits = end_logits.squeeze(-1).contiguous()
  878. total_loss = None
  879. if start_positions is not None and end_positions is not None:
  880. # If we are on multi-GPU, split add a dimension
  881. if len(start_positions.size()) > 1:
  882. start_positions = start_positions.squeeze(-1)
  883. if len(end_positions.size()) > 1:
  884. end_positions = end_positions.squeeze(-1)
  885. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  886. ignored_index = start_logits.size(1)
  887. start_positions = start_positions.clamp(0, ignored_index)
  888. end_positions = end_positions.clamp(0, ignored_index)
  889. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  890. start_loss = loss_fct(start_logits, start_positions)
  891. end_loss = loss_fct(end_logits, end_positions)
  892. total_loss = (start_loss + end_loss) / 2
  893. if not return_dict:
  894. output = (start_logits, end_logits) + outputs[2:]
  895. return ((total_loss,) + output) if total_loss is not None else output
  896. return QuestionAnsweringModelOutput(
  897. loss=total_loss,
  898. start_logits=start_logits,
  899. end_logits=end_logits,
  900. hidden_states=outputs.hidden_states,
  901. attentions=outputs.attentions,
  902. )
# Public names of this module: the LiLT base model, its pretrained base class,
# and the task-specific heads. This list controls what `from ... import *` exposes.
__all__ = [
    "LiltForQuestionAnswering",
    "LiltForSequenceClassification",
    "LiltForTokenClassification",
    "LiltModel",
    "LiltPreTrainedModel",
]