modeling_layoutlm.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144
  1. # coding=utf-8
  2. # Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch LayoutLM model."""
  16. from typing import Callable, Optional, Union
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ...activations import ACT2FN
  21. from ...modeling_layers import GradientCheckpointingLayer
  22. from ...modeling_outputs import (
  23. BaseModelOutput,
  24. BaseModelOutputWithPooling,
  25. MaskedLMOutput,
  26. QuestionAnsweringModelOutput,
  27. SequenceClassifierOutput,
  28. TokenClassifierOutput,
  29. )
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  31. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  32. from ...utils import auto_docstring, can_return_tuple, logging
  33. from .configuration_layoutlm import LayoutLMConfig
  34. logger = logging.get_logger(__name__)
  35. LayoutLMLayerNorm = nn.LayerNorm
  36. class LayoutLMEmbeddings(nn.Module):
  37. """Construct the embeddings from word, position and token_type embeddings."""
  38. def __init__(self, config):
  39. super().__init__()
  40. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  41. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  42. self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
  43. self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
  44. self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
  45. self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
  46. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  47. self.LayerNorm = LayoutLMLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  48. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  49. self.register_buffer(
  50. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  51. )
  52. def forward(
  53. self,
  54. input_ids=None,
  55. bbox=None,
  56. token_type_ids=None,
  57. position_ids=None,
  58. inputs_embeds=None,
  59. ):
  60. if input_ids is not None:
  61. input_shape = input_ids.size()
  62. else:
  63. input_shape = inputs_embeds.size()[:-1]
  64. seq_length = input_shape[1]
  65. device = input_ids.device if input_ids is not None else inputs_embeds.device
  66. if position_ids is None:
  67. position_ids = self.position_ids[:, :seq_length]
  68. if token_type_ids is None:
  69. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  70. if inputs_embeds is None:
  71. inputs_embeds = self.word_embeddings(input_ids)
  72. words_embeddings = inputs_embeds
  73. position_embeddings = self.position_embeddings(position_ids)
  74. try:
  75. left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
  76. upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
  77. right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
  78. lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
  79. except IndexError as e:
  80. raise IndexError("The `bbox`coordinate values should be within 0-1000 range.") from e
  81. h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
  82. w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
  83. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  84. embeddings = (
  85. words_embeddings
  86. + position_embeddings
  87. + left_position_embeddings
  88. + upper_position_embeddings
  89. + right_position_embeddings
  90. + lower_position_embeddings
  91. + h_position_embeddings
  92. + w_position_embeddings
  93. + token_type_embeddings
  94. )
  95. embeddings = self.LayerNorm(embeddings)
  96. embeddings = self.dropout(embeddings)
  97. return embeddings
  98. # Copied from transformers.models.align.modeling_align.eager_attention_forward
  99. def eager_attention_forward(
  100. module: nn.Module,
  101. query: torch.Tensor,
  102. key: torch.Tensor,
  103. value: torch.Tensor,
  104. attention_mask: Optional[torch.Tensor],
  105. scaling: float,
  106. dropout: float = 0.0,
  107. head_mask: Optional[torch.Tensor] = None,
  108. **kwargs,
  109. ):
  110. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  111. if attention_mask is not None:
  112. causal_mask = attention_mask[:, :, :, : key.shape[-2]]
  113. attn_weights = attn_weights + causal_mask
  114. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  115. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  116. if head_mask is not None:
  117. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  118. attn_output = torch.matmul(attn_weights, value)
  119. attn_output = attn_output.transpose(1, 2).contiguous()
  120. return attn_output, attn_weights
  121. # Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->LayoutLM
  122. class LayoutLMSelfAttention(nn.Module):
  123. def __init__(self, config):
  124. super().__init__()
  125. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  126. raise ValueError(
  127. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  128. f"heads ({config.num_attention_heads})"
  129. )
  130. self.config = config
  131. self.num_attention_heads = config.num_attention_heads
  132. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  133. self.all_head_size = self.num_attention_heads * self.attention_head_size
  134. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  135. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  136. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  137. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  138. self.attention_dropout = config.attention_probs_dropout_prob
  139. self.scaling = self.attention_head_size**-0.5
  140. def forward(
  141. self,
  142. hidden_states: torch.Tensor,
  143. attention_mask: Optional[torch.FloatTensor] = None,
  144. head_mask: Optional[torch.FloatTensor] = None,
  145. output_attentions: Optional[bool] = False,
  146. **kwargs,
  147. ) -> tuple[torch.Tensor]:
  148. input_shape = hidden_states.shape[:-1]
  149. hidden_shape = (*input_shape, -1, self.attention_head_size)
  150. query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  151. key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  152. value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  153. attention_interface: Callable = eager_attention_forward
  154. if self.config._attn_implementation != "eager":
  155. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  156. attn_output, attn_weights = attention_interface(
  157. self,
  158. query_states,
  159. key_states,
  160. value_states,
  161. attention_mask,
  162. dropout=0.0 if not self.training else self.attention_dropout,
  163. scaling=self.scaling,
  164. head_mask=head_mask,
  165. **kwargs,
  166. )
  167. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  168. outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
  169. return outputs
  170. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->LayoutLM
  171. class LayoutLMSelfOutput(nn.Module):
  172. def __init__(self, config):
  173. super().__init__()
  174. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  175. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  176. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  177. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  178. hidden_states = self.dense(hidden_states)
  179. hidden_states = self.dropout(hidden_states)
  180. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  181. return hidden_states
  182. # Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->LayoutLM
  183. class LayoutLMAttention(nn.Module):
  184. def __init__(self, config):
  185. super().__init__()
  186. self.self = LayoutLMSelfAttention(config)
  187. self.output = LayoutLMSelfOutput(config)
  188. self.pruned_heads = set()
  189. def prune_heads(self, heads):
  190. if len(heads) == 0:
  191. return
  192. heads, index = find_pruneable_heads_and_indices(
  193. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  194. )
  195. # Prune linear layers
  196. self.self.query = prune_linear_layer(self.self.query, index)
  197. self.self.key = prune_linear_layer(self.self.key, index)
  198. self.self.value = prune_linear_layer(self.self.value, index)
  199. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  200. # Update hyper params and store pruned heads
  201. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  202. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  203. self.pruned_heads = self.pruned_heads.union(heads)
  204. def forward(
  205. self,
  206. hidden_states: torch.Tensor,
  207. attention_mask: Optional[torch.FloatTensor] = None,
  208. head_mask: Optional[torch.FloatTensor] = None,
  209. output_attentions: Optional[bool] = False,
  210. **kwargs,
  211. ) -> tuple[torch.Tensor]:
  212. self_outputs = self.self(
  213. hidden_states,
  214. attention_mask=attention_mask,
  215. head_mask=head_mask,
  216. output_attentions=output_attentions,
  217. **kwargs,
  218. )
  219. attention_output = self.output(self_outputs[0], hidden_states)
  220. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  221. return outputs
  222. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  223. class LayoutLMIntermediate(nn.Module):
  224. def __init__(self, config):
  225. super().__init__()
  226. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  227. if isinstance(config.hidden_act, str):
  228. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  229. else:
  230. self.intermediate_act_fn = config.hidden_act
  231. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  232. hidden_states = self.dense(hidden_states)
  233. hidden_states = self.intermediate_act_fn(hidden_states)
  234. return hidden_states
  235. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->LayoutLM
  236. class LayoutLMOutput(nn.Module):
  237. def __init__(self, config):
  238. super().__init__()
  239. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  240. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  241. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  242. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  243. hidden_states = self.dense(hidden_states)
  244. hidden_states = self.dropout(hidden_states)
  245. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  246. return hidden_states
  247. # Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->LayoutLM
  248. class LayoutLMLayer(GradientCheckpointingLayer):
  249. def __init__(self, config):
  250. super().__init__()
  251. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  252. self.seq_len_dim = 1
  253. self.attention = LayoutLMAttention(config)
  254. self.intermediate = LayoutLMIntermediate(config)
  255. self.output = LayoutLMOutput(config)
  256. def forward(
  257. self,
  258. hidden_states: torch.Tensor,
  259. attention_mask: Optional[torch.FloatTensor] = None,
  260. head_mask: Optional[torch.FloatTensor] = None,
  261. output_attentions: Optional[bool] = False,
  262. **kwargs,
  263. ) -> tuple[torch.Tensor]:
  264. self_attention_outputs = self.attention(
  265. hidden_states,
  266. attention_mask=attention_mask,
  267. head_mask=head_mask,
  268. output_attentions=output_attentions,
  269. **kwargs,
  270. )
  271. attention_output = self_attention_outputs[0]
  272. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  273. layer_output = apply_chunking_to_forward(
  274. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  275. )
  276. outputs = (layer_output,) + outputs
  277. return outputs
  278. def feed_forward_chunk(self, attention_output):
  279. intermediate_output = self.intermediate(attention_output)
  280. layer_output = self.output(intermediate_output, attention_output)
  281. return layer_output
  282. # Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->LayoutLM
  283. class LayoutLMEncoder(nn.Module):
  284. def __init__(self, config):
  285. super().__init__()
  286. self.config = config
  287. self.layer = nn.ModuleList([LayoutLMLayer(config) for i in range(config.num_hidden_layers)])
  288. self.gradient_checkpointing = False
  289. @can_return_tuple
  290. def forward(
  291. self,
  292. hidden_states: torch.Tensor,
  293. attention_mask: Optional[torch.FloatTensor] = None,
  294. head_mask: Optional[torch.FloatTensor] = None,
  295. output_attentions: Optional[bool] = False,
  296. output_hidden_states: Optional[bool] = False,
  297. return_dict: Optional[bool] = True,
  298. **kwargs,
  299. ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
  300. all_hidden_states = () if output_hidden_states else None
  301. all_self_attentions = () if output_attentions else None
  302. for i, layer_module in enumerate(self.layer):
  303. if output_hidden_states:
  304. all_hidden_states = all_hidden_states + (hidden_states,)
  305. layer_head_mask = head_mask[i] if head_mask is not None else None
  306. layer_outputs = layer_module(
  307. hidden_states=hidden_states,
  308. attention_mask=attention_mask,
  309. head_mask=layer_head_mask,
  310. output_attentions=output_attentions,
  311. **kwargs,
  312. )
  313. hidden_states = layer_outputs[0]
  314. if output_attentions:
  315. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  316. if output_hidden_states:
  317. all_hidden_states = all_hidden_states + (hidden_states,)
  318. return BaseModelOutput(
  319. last_hidden_state=hidden_states,
  320. hidden_states=all_hidden_states,
  321. attentions=all_self_attentions,
  322. )
  323. # Copied from transformers.models.bert.modeling_bert.BertPooler
  324. class LayoutLMPooler(nn.Module):
  325. def __init__(self, config):
  326. super().__init__()
  327. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  328. self.activation = nn.Tanh()
  329. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  330. # We "pool" the model by simply taking the hidden state corresponding
  331. # to the first token.
  332. first_token_tensor = hidden_states[:, 0]
  333. pooled_output = self.dense(first_token_tensor)
  334. pooled_output = self.activation(pooled_output)
  335. return pooled_output
  336. # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->LayoutLM
  337. class LayoutLMPredictionHeadTransform(nn.Module):
  338. def __init__(self, config):
  339. super().__init__()
  340. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  341. if isinstance(config.hidden_act, str):
  342. self.transform_act_fn = ACT2FN[config.hidden_act]
  343. else:
  344. self.transform_act_fn = config.hidden_act
  345. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  346. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  347. hidden_states = self.dense(hidden_states)
  348. hidden_states = self.transform_act_fn(hidden_states)
  349. hidden_states = self.LayerNorm(hidden_states)
  350. return hidden_states
  351. # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->LayoutLM
  352. class LayoutLMLMPredictionHead(nn.Module):
  353. def __init__(self, config):
  354. super().__init__()
  355. self.transform = LayoutLMPredictionHeadTransform(config)
  356. # The output weights are the same as the input embeddings, but there is
  357. # an output-only bias for each token.
  358. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  359. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  360. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  361. self.decoder.bias = self.bias
  362. def _tie_weights(self):
  363. self.decoder.bias = self.bias
  364. def forward(self, hidden_states):
  365. hidden_states = self.transform(hidden_states)
  366. hidden_states = self.decoder(hidden_states)
  367. return hidden_states
  368. # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->LayoutLM
  369. class LayoutLMOnlyMLMHead(nn.Module):
  370. def __init__(self, config):
  371. super().__init__()
  372. self.predictions = LayoutLMLMPredictionHead(config)
  373. def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
  374. prediction_scores = self.predictions(sequence_output)
  375. return prediction_scores
  376. @auto_docstring
  377. class LayoutLMPreTrainedModel(PreTrainedModel):
  378. config: LayoutLMConfig
  379. base_model_prefix = "layoutlm"
  380. supports_gradient_checkpointing = True
  381. def _init_weights(self, module):
  382. """Initialize the weights"""
  383. if isinstance(module, nn.Linear):
  384. # Slightly different from the TF version which uses truncated_normal for initialization
  385. # cf https://github.com/pytorch/pytorch/pull/5617
  386. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  387. if module.bias is not None:
  388. module.bias.data.zero_()
  389. elif isinstance(module, nn.Embedding):
  390. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  391. if module.padding_idx is not None:
  392. module.weight.data[module.padding_idx].zero_()
  393. elif isinstance(module, LayoutLMLayerNorm):
  394. module.bias.data.zero_()
  395. module.weight.data.fill_(1.0)
  396. elif isinstance(module, LayoutLMLMPredictionHead):
  397. module.bias.data.zero_()
  398. @auto_docstring
  399. class LayoutLMModel(LayoutLMPreTrainedModel):
  400. def __init__(self, config):
  401. super().__init__(config)
  402. self.config = config
  403. self.embeddings = LayoutLMEmbeddings(config)
  404. self.encoder = LayoutLMEncoder(config)
  405. self.pooler = LayoutLMPooler(config)
  406. # Initialize weights and apply final processing
  407. self.post_init()
  408. def get_input_embeddings(self):
  409. return self.embeddings.word_embeddings
  410. def set_input_embeddings(self, value):
  411. self.embeddings.word_embeddings = value
  412. def _prune_heads(self, heads_to_prune):
  413. """
  414. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  415. class PreTrainedModel
  416. """
  417. for layer, heads in heads_to_prune.items():
  418. self.encoder.layer[layer].attention.prune_heads(heads)
  419. @can_return_tuple
  420. @auto_docstring
  421. def forward(
  422. self,
  423. input_ids: Optional[torch.LongTensor] = None,
  424. bbox: Optional[torch.LongTensor] = None,
  425. attention_mask: Optional[torch.FloatTensor] = None,
  426. token_type_ids: Optional[torch.LongTensor] = None,
  427. position_ids: Optional[torch.LongTensor] = None,
  428. head_mask: Optional[torch.FloatTensor] = None,
  429. inputs_embeds: Optional[torch.FloatTensor] = None,
  430. output_attentions: Optional[bool] = None,
  431. output_hidden_states: Optional[bool] = None,
  432. return_dict: Optional[bool] = None,
  433. ) -> Union[tuple, BaseModelOutputWithPooling]:
  434. r"""
  435. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  436. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  437. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  438. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  439. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  440. Examples:
  441. ```python
  442. >>> from transformers import AutoTokenizer, LayoutLMModel
  443. >>> import torch
  444. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
  445. >>> model = LayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased")
  446. >>> words = ["Hello", "world"]
  447. >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
  448. >>> token_boxes = []
  449. >>> for word, box in zip(words, normalized_word_boxes):
  450. ... word_tokens = tokenizer.tokenize(word)
  451. ... token_boxes.extend([box] * len(word_tokens))
  452. >>> # add bounding boxes of cls + sep tokens
  453. >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
  454. >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
  455. >>> input_ids = encoding["input_ids"]
  456. >>> attention_mask = encoding["attention_mask"]
  457. >>> token_type_ids = encoding["token_type_ids"]
  458. >>> bbox = torch.tensor([token_boxes])
  459. >>> outputs = model(
  460. ... input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids
  461. ... )
  462. >>> last_hidden_states = outputs.last_hidden_state
  463. ```"""
  464. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  465. output_hidden_states = (
  466. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  467. )
  468. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  469. if input_ids is not None and inputs_embeds is not None:
  470. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  471. elif input_ids is not None:
  472. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  473. input_shape = input_ids.size()
  474. elif inputs_embeds is not None:
  475. input_shape = inputs_embeds.size()[:-1]
  476. else:
  477. raise ValueError("You have to specify either input_ids or inputs_embeds")
  478. device = input_ids.device if input_ids is not None else inputs_embeds.device
  479. if attention_mask is None:
  480. attention_mask = torch.ones(input_shape, device=device)
  481. if token_type_ids is None:
  482. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  483. if bbox is None:
  484. bbox = torch.zeros(input_shape + (4,), dtype=torch.long, device=device)
  485. extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  486. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
  487. extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
  488. if head_mask is not None:
  489. if head_mask.dim() == 1:
  490. head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
  491. head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
  492. elif head_mask.dim() == 2:
  493. head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
  494. head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
  495. else:
  496. head_mask = [None] * self.config.num_hidden_layers
  497. embedding_output = self.embeddings(
  498. input_ids=input_ids,
  499. bbox=bbox,
  500. position_ids=position_ids,
  501. token_type_ids=token_type_ids,
  502. inputs_embeds=inputs_embeds,
  503. )
  504. encoder_outputs = self.encoder(
  505. embedding_output,
  506. extended_attention_mask,
  507. head_mask=head_mask,
  508. output_attentions=output_attentions,
  509. output_hidden_states=output_hidden_states,
  510. return_dict=True,
  511. )
  512. sequence_output = encoder_outputs[0]
  513. pooled_output = self.pooler(sequence_output)
  514. return BaseModelOutputWithPooling(
  515. last_hidden_state=sequence_output,
  516. pooler_output=pooled_output,
  517. hidden_states=encoder_outputs.hidden_states,
  518. attentions=encoder_outputs.attentions,
  519. )
  520. @auto_docstring
  521. class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
  522. _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
  523. def __init__(self, config):
  524. super().__init__(config)
  525. self.layoutlm = LayoutLMModel(config)
  526. self.cls = LayoutLMOnlyMLMHead(config)
  527. # Initialize weights and apply final processing
  528. self.post_init()
  529. def get_input_embeddings(self):
  530. return self.layoutlm.embeddings.word_embeddings
  531. def get_output_embeddings(self):
  532. return self.cls.predictions.decoder
  533. def set_output_embeddings(self, new_embeddings):
  534. self.cls.predictions.decoder = new_embeddings
  535. self.cls.predictions.bias = new_embeddings.bias
  536. @can_return_tuple
  537. @auto_docstring
  538. def forward(
  539. self,
  540. input_ids: Optional[torch.LongTensor] = None,
  541. bbox: Optional[torch.LongTensor] = None,
  542. attention_mask: Optional[torch.FloatTensor] = None,
  543. token_type_ids: Optional[torch.LongTensor] = None,
  544. position_ids: Optional[torch.LongTensor] = None,
  545. head_mask: Optional[torch.FloatTensor] = None,
  546. inputs_embeds: Optional[torch.FloatTensor] = None,
  547. labels: Optional[torch.LongTensor] = None,
  548. output_attentions: Optional[bool] = None,
  549. output_hidden_states: Optional[bool] = None,
  550. return_dict: Optional[bool] = None,
  551. ) -> Union[tuple, MaskedLMOutput]:
  552. r"""
  553. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  554. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  555. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  556. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  557. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  558. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  559. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  560. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  561. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  562. Examples:
  563. ```python
  564. >>> from transformers import AutoTokenizer, LayoutLMForMaskedLM
  565. >>> import torch
  566. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
  567. >>> model = LayoutLMForMaskedLM.from_pretrained("microsoft/layoutlm-base-uncased")
  568. >>> words = ["Hello", "[MASK]"]
  569. >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
  570. >>> token_boxes = []
  571. >>> for word, box in zip(words, normalized_word_boxes):
  572. ... word_tokens = tokenizer.tokenize(word)
  573. ... token_boxes.extend([box] * len(word_tokens))
  574. >>> # add bounding boxes of cls + sep tokens
  575. >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
  576. >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
  577. >>> input_ids = encoding["input_ids"]
  578. >>> attention_mask = encoding["attention_mask"]
  579. >>> token_type_ids = encoding["token_type_ids"]
  580. >>> bbox = torch.tensor([token_boxes])
  581. >>> labels = tokenizer("Hello world", return_tensors="pt")["input_ids"]
  582. >>> outputs = model(
  583. ... input_ids=input_ids,
  584. ... bbox=bbox,
  585. ... attention_mask=attention_mask,
  586. ... token_type_ids=token_type_ids,
  587. ... labels=labels,
  588. ... )
  589. >>> loss = outputs.loss
  590. ```"""
  591. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  592. outputs = self.layoutlm(
  593. input_ids,
  594. bbox,
  595. attention_mask=attention_mask,
  596. token_type_ids=token_type_ids,
  597. position_ids=position_ids,
  598. head_mask=head_mask,
  599. inputs_embeds=inputs_embeds,
  600. output_attentions=output_attentions,
  601. output_hidden_states=output_hidden_states,
  602. return_dict=True,
  603. )
  604. sequence_output = outputs[0]
  605. prediction_scores = self.cls(sequence_output)
  606. masked_lm_loss = None
  607. if labels is not None:
  608. loss_fct = CrossEntropyLoss()
  609. masked_lm_loss = loss_fct(
  610. prediction_scores.view(-1, self.config.vocab_size),
  611. labels.view(-1),
  612. )
  613. return MaskedLMOutput(
  614. loss=masked_lm_loss,
  615. logits=prediction_scores,
  616. hidden_states=outputs.hidden_states,
  617. attentions=outputs.attentions,
  618. )
  619. @auto_docstring(
  620. custom_intro="""
  621. LayoutLM Model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for
  622. document image classification tasks such as the [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset.
  623. """
  624. )
  625. class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
  626. def __init__(self, config):
  627. super().__init__(config)
  628. self.num_labels = config.num_labels
  629. self.layoutlm = LayoutLMModel(config)
  630. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  631. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  632. # Initialize weights and apply final processing
  633. self.post_init()
  634. def get_input_embeddings(self):
  635. return self.layoutlm.embeddings.word_embeddings
  636. @can_return_tuple
  637. @auto_docstring
  638. def forward(
  639. self,
  640. input_ids: Optional[torch.LongTensor] = None,
  641. bbox: Optional[torch.LongTensor] = None,
  642. attention_mask: Optional[torch.FloatTensor] = None,
  643. token_type_ids: Optional[torch.LongTensor] = None,
  644. position_ids: Optional[torch.LongTensor] = None,
  645. head_mask: Optional[torch.FloatTensor] = None,
  646. inputs_embeds: Optional[torch.FloatTensor] = None,
  647. labels: Optional[torch.LongTensor] = None,
  648. output_attentions: Optional[bool] = None,
  649. output_hidden_states: Optional[bool] = None,
  650. return_dict: Optional[bool] = None,
  651. ) -> Union[tuple, SequenceClassifierOutput]:
  652. r"""
  653. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  654. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  655. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  656. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  657. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  658. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  659. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  660. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  661. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  662. Examples:
  663. ```python
  664. >>> from transformers import AutoTokenizer, LayoutLMForSequenceClassification
  665. >>> import torch
  666. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
  667. >>> model = LayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased")
  668. >>> words = ["Hello", "world"]
  669. >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
  670. >>> token_boxes = []
  671. >>> for word, box in zip(words, normalized_word_boxes):
  672. ... word_tokens = tokenizer.tokenize(word)
  673. ... token_boxes.extend([box] * len(word_tokens))
  674. >>> # add bounding boxes of cls + sep tokens
  675. >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
  676. >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
  677. >>> input_ids = encoding["input_ids"]
  678. >>> attention_mask = encoding["attention_mask"]
  679. >>> token_type_ids = encoding["token_type_ids"]
  680. >>> bbox = torch.tensor([token_boxes])
  681. >>> sequence_label = torch.tensor([1])
  682. >>> outputs = model(
  683. ... input_ids=input_ids,
  684. ... bbox=bbox,
  685. ... attention_mask=attention_mask,
  686. ... token_type_ids=token_type_ids,
  687. ... labels=sequence_label,
  688. ... )
  689. >>> loss = outputs.loss
  690. >>> logits = outputs.logits
  691. ```"""
  692. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  693. outputs = self.layoutlm(
  694. input_ids=input_ids,
  695. bbox=bbox,
  696. attention_mask=attention_mask,
  697. token_type_ids=token_type_ids,
  698. position_ids=position_ids,
  699. head_mask=head_mask,
  700. inputs_embeds=inputs_embeds,
  701. output_attentions=output_attentions,
  702. output_hidden_states=output_hidden_states,
  703. return_dict=True,
  704. )
  705. pooled_output = outputs[1]
  706. pooled_output = self.dropout(pooled_output)
  707. logits = self.classifier(pooled_output)
  708. loss = None
  709. if labels is not None:
  710. if self.config.problem_type is None:
  711. if self.num_labels == 1:
  712. self.config.problem_type = "regression"
  713. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  714. self.config.problem_type = "single_label_classification"
  715. else:
  716. self.config.problem_type = "multi_label_classification"
  717. if self.config.problem_type == "regression":
  718. loss_fct = MSELoss()
  719. if self.num_labels == 1:
  720. loss = loss_fct(logits.squeeze(), labels.squeeze())
  721. else:
  722. loss = loss_fct(logits, labels)
  723. elif self.config.problem_type == "single_label_classification":
  724. loss_fct = CrossEntropyLoss()
  725. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  726. elif self.config.problem_type == "multi_label_classification":
  727. loss_fct = BCEWithLogitsLoss()
  728. loss = loss_fct(logits, labels)
  729. return SequenceClassifierOutput(
  730. loss=loss,
  731. logits=logits,
  732. hidden_states=outputs.hidden_states,
  733. attentions=outputs.attentions,
  734. )
  735. @auto_docstring(
  736. custom_intro="""
  737. LayoutLM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
  738. sequence labeling (information extraction) tasks such as the [FUNSD](https://guillaumejaume.github.io/FUNSD/)
  739. dataset and the [SROIE](https://rrc.cvc.uab.es/?ch=13) dataset.
  740. """
  741. )
  742. class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
  743. def __init__(self, config):
  744. super().__init__(config)
  745. self.num_labels = config.num_labels
  746. self.layoutlm = LayoutLMModel(config)
  747. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  748. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  749. # Initialize weights and apply final processing
  750. self.post_init()
  751. def get_input_embeddings(self):
  752. return self.layoutlm.embeddings.word_embeddings
  753. @can_return_tuple
  754. @auto_docstring
  755. def forward(
  756. self,
  757. input_ids: Optional[torch.LongTensor] = None,
  758. bbox: Optional[torch.LongTensor] = None,
  759. attention_mask: Optional[torch.FloatTensor] = None,
  760. token_type_ids: Optional[torch.LongTensor] = None,
  761. position_ids: Optional[torch.LongTensor] = None,
  762. head_mask: Optional[torch.FloatTensor] = None,
  763. inputs_embeds: Optional[torch.FloatTensor] = None,
  764. labels: Optional[torch.LongTensor] = None,
  765. output_attentions: Optional[bool] = None,
  766. output_hidden_states: Optional[bool] = None,
  767. return_dict: Optional[bool] = None,
  768. ) -> Union[tuple, TokenClassifierOutput]:
  769. r"""
  770. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  771. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  772. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  773. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  774. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  775. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  776. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  777. Examples:
  778. ```python
  779. >>> from transformers import AutoTokenizer, LayoutLMForTokenClassification
  780. >>> import torch
  781. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
  782. >>> model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased")
  783. >>> words = ["Hello", "world"]
  784. >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
  785. >>> token_boxes = []
  786. >>> for word, box in zip(words, normalized_word_boxes):
  787. ... word_tokens = tokenizer.tokenize(word)
  788. ... token_boxes.extend([box] * len(word_tokens))
  789. >>> # add bounding boxes of cls + sep tokens
  790. >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
  791. >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
  792. >>> input_ids = encoding["input_ids"]
  793. >>> attention_mask = encoding["attention_mask"]
  794. >>> token_type_ids = encoding["token_type_ids"]
  795. >>> bbox = torch.tensor([token_boxes])
  796. >>> token_labels = torch.tensor([1, 1, 0, 0]).unsqueeze(0) # batch size of 1
  797. >>> outputs = model(
  798. ... input_ids=input_ids,
  799. ... bbox=bbox,
  800. ... attention_mask=attention_mask,
  801. ... token_type_ids=token_type_ids,
  802. ... labels=token_labels,
  803. ... )
  804. >>> loss = outputs.loss
  805. >>> logits = outputs.logits
  806. ```"""
  807. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  808. outputs = self.layoutlm(
  809. input_ids=input_ids,
  810. bbox=bbox,
  811. attention_mask=attention_mask,
  812. token_type_ids=token_type_ids,
  813. position_ids=position_ids,
  814. head_mask=head_mask,
  815. inputs_embeds=inputs_embeds,
  816. output_attentions=output_attentions,
  817. output_hidden_states=output_hidden_states,
  818. return_dict=True,
  819. )
  820. sequence_output = outputs[0]
  821. sequence_output = self.dropout(sequence_output)
  822. logits = self.classifier(sequence_output)
  823. loss = None
  824. if labels is not None:
  825. loss_fct = CrossEntropyLoss()
  826. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  827. return TokenClassifierOutput(
  828. loss=loss,
  829. logits=logits,
  830. hidden_states=outputs.hidden_states,
  831. attentions=outputs.attentions,
  832. )
  833. @auto_docstring
  834. class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
  835. def __init__(self, config, has_visual_segment_embedding=True):
  836. r"""
  837. has_visual_segment_embedding (`bool`, *optional*, defaults to `True`):
  838. Whether or not to add visual segment embeddings.
  839. """
  840. super().__init__(config)
  841. self.num_labels = config.num_labels
  842. self.layoutlm = LayoutLMModel(config)
  843. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  844. # Initialize weights and apply final processing
  845. self.post_init()
  846. def get_input_embeddings(self):
  847. return self.layoutlm.embeddings.word_embeddings
  848. @can_return_tuple
  849. @auto_docstring
  850. def forward(
  851. self,
  852. input_ids: Optional[torch.LongTensor] = None,
  853. bbox: Optional[torch.LongTensor] = None,
  854. attention_mask: Optional[torch.FloatTensor] = None,
  855. token_type_ids: Optional[torch.LongTensor] = None,
  856. position_ids: Optional[torch.LongTensor] = None,
  857. head_mask: Optional[torch.FloatTensor] = None,
  858. inputs_embeds: Optional[torch.FloatTensor] = None,
  859. start_positions: Optional[torch.LongTensor] = None,
  860. end_positions: Optional[torch.LongTensor] = None,
  861. output_attentions: Optional[bool] = None,
  862. output_hidden_states: Optional[bool] = None,
  863. return_dict: Optional[bool] = None,
  864. ) -> Union[tuple, QuestionAnsweringModelOutput]:
  865. r"""
  866. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  867. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  868. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  869. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  870. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  871. Example:
  872. In the example below, we prepare a question + context pair for the LayoutLM model. It will give us a prediction
  873. of what it thinks the answer is (the span of the answer within the texts parsed from the image).
  874. ```python
  875. >>> from transformers import AutoTokenizer, LayoutLMForQuestionAnswering
  876. >>> from datasets import load_dataset
  877. >>> import torch
  878. >>> tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", add_prefix_space=True)
  879. >>> model = LayoutLMForQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="1e3ebac")
  880. >>> dataset = load_dataset("nielsr/funsd", split="train")
  881. >>> example = dataset[0]
  882. >>> question = "what's his name?"
  883. >>> words = example["words"]
  884. >>> boxes = example["bboxes"]
  885. >>> encoding = tokenizer(
  886. ... question.split(), words, is_split_into_words=True, return_token_type_ids=True, return_tensors="pt"
  887. ... )
  888. >>> bbox = []
  889. >>> for i, s, w in zip(encoding.input_ids[0], encoding.sequence_ids(0), encoding.word_ids(0)):
  890. ... if s == 1:
  891. ... bbox.append(boxes[w])
  892. ... elif i == tokenizer.sep_token_id:
  893. ... bbox.append([1000] * 4)
  894. ... else:
  895. ... bbox.append([0] * 4)
  896. >>> encoding["bbox"] = torch.tensor([bbox])
  897. >>> word_ids = encoding.word_ids(0)
  898. >>> outputs = model(**encoding)
  899. >>> loss = outputs.loss
  900. >>> start_scores = outputs.start_logits
  901. >>> end_scores = outputs.end_logits
  902. >>> start, end = word_ids[start_scores.argmax(-1)], word_ids[end_scores.argmax(-1)]
  903. >>> print(" ".join(words[start : end + 1]))
  904. M. Hamann P. Harper, P. Martinez
  905. ```"""
  906. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  907. outputs = self.layoutlm(
  908. input_ids=input_ids,
  909. bbox=bbox,
  910. attention_mask=attention_mask,
  911. token_type_ids=token_type_ids,
  912. position_ids=position_ids,
  913. head_mask=head_mask,
  914. inputs_embeds=inputs_embeds,
  915. output_attentions=output_attentions,
  916. output_hidden_states=output_hidden_states,
  917. return_dict=True,
  918. )
  919. sequence_output = outputs[0]
  920. logits = self.qa_outputs(sequence_output)
  921. start_logits, end_logits = logits.split(1, dim=-1)
  922. start_logits = start_logits.squeeze(-1).contiguous()
  923. end_logits = end_logits.squeeze(-1).contiguous()
  924. total_loss = None
  925. if start_positions is not None and end_positions is not None:
  926. # If we are on multi-GPU, split add a dimension
  927. if len(start_positions.size()) > 1:
  928. start_positions = start_positions.squeeze(-1)
  929. if len(end_positions.size()) > 1:
  930. end_positions = end_positions.squeeze(-1)
  931. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  932. ignored_index = start_logits.size(1)
  933. start_positions = start_positions.clamp(0, ignored_index)
  934. end_positions = end_positions.clamp(0, ignored_index)
  935. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  936. start_loss = loss_fct(start_logits, start_positions)
  937. end_loss = loss_fct(end_logits, end_positions)
  938. total_loss = (start_loss + end_loss) / 2
  939. return QuestionAnsweringModelOutput(
  940. loss=total_loss,
  941. start_logits=start_logits,
  942. end_logits=end_logits,
  943. hidden_states=outputs.hidden_states,
  944. attentions=outputs.attentions,
  945. )
# Public names of this module: the ones exported by a star-import.
__all__ = [
    "LayoutLMForMaskedLM",
    "LayoutLMForSequenceClassification",
    "LayoutLMForTokenClassification",
    "LayoutLMForQuestionAnswering",
    "LayoutLMModel",
    "LayoutLMPreTrainedModel",
]