modeling_layoutlmv3.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253
  1. # coding=utf-8
  2. # Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch LayoutLMv3 model."""
  16. import collections
  17. import math
  18. from typing import Optional, Union
  19. import torch
  20. import torch.nn as nn
  21. import torch.nn.functional as F
  22. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  23. from ...activations import ACT2FN
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import (
  26. BaseModelOutput,
  27. QuestionAnsweringModelOutput,
  28. SequenceClassifierOutput,
  29. TokenClassifierOutput,
  30. )
  31. from ...modeling_utils import PreTrainedModel
  32. from ...pytorch_utils import apply_chunking_to_forward
  33. from ...utils import (
  34. auto_docstring,
  35. logging,
  36. torch_int,
  37. )
  38. from .configuration_layoutlmv3 import LayoutLMv3Config
# Module-level logger, obtained from the package's `utils.logging` helper imported above.
logger = logging.get_logger(__name__)
  40. class LayoutLMv3PatchEmbeddings(nn.Module):
  41. """LayoutLMv3 image (patch) embeddings. This class also automatically interpolates the position embeddings for varying
  42. image sizes."""
  43. def __init__(self, config):
  44. super().__init__()
  45. image_size = (
  46. config.input_size
  47. if isinstance(config.input_size, collections.abc.Iterable)
  48. else (config.input_size, config.input_size)
  49. )
  50. patch_size = (
  51. config.patch_size
  52. if isinstance(config.patch_size, collections.abc.Iterable)
  53. else (config.patch_size, config.patch_size)
  54. )
  55. self.patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  56. self.proj = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size)
  57. def forward(self, pixel_values, position_embedding=None):
  58. embeddings = self.proj(pixel_values)
  59. if position_embedding is not None:
  60. # interpolate the position embedding to the corresponding size
  61. position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1)
  62. position_embedding = position_embedding.permute(0, 3, 1, 2)
  63. patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
  64. position_embedding = F.interpolate(position_embedding, size=(patch_height, patch_width), mode="bicubic")
  65. embeddings = embeddings + position_embedding
  66. embeddings = embeddings.flatten(2).transpose(1, 2)
  67. return embeddings
  68. class LayoutLMv3TextEmbeddings(nn.Module):
  69. """
  70. LayoutLMv3 text embeddings. Same as `RobertaEmbeddings` but with added spatial (layout) embeddings.
  71. """
  72. def __init__(self, config):
  73. super().__init__()
  74. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  75. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  76. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  77. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  78. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  79. self.register_buffer(
  80. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  81. )
  82. self.padding_idx = config.pad_token_id
  83. self.position_embeddings = nn.Embedding(
  84. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  85. )
  86. self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
  87. self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
  88. self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
  89. self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
  90. def calculate_spatial_position_embeddings(self, bbox):
  91. try:
  92. left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
  93. upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
  94. right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
  95. lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
  96. except IndexError as e:
  97. raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e
  98. h_position_embeddings = self.h_position_embeddings(torch.clip(bbox[:, :, 3] - bbox[:, :, 1], 0, 1023))
  99. w_position_embeddings = self.w_position_embeddings(torch.clip(bbox[:, :, 2] - bbox[:, :, 0], 0, 1023))
  100. # below is the difference between LayoutLMEmbeddingsV2 (torch.cat) and LayoutLMEmbeddingsV1 (add)
  101. spatial_position_embeddings = torch.cat(
  102. [
  103. left_position_embeddings,
  104. upper_position_embeddings,
  105. right_position_embeddings,
  106. lower_position_embeddings,
  107. h_position_embeddings,
  108. w_position_embeddings,
  109. ],
  110. dim=-1,
  111. )
  112. return spatial_position_embeddings
  113. def create_position_ids_from_input_ids(self, input_ids, padding_idx):
  114. """
  115. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
  116. symbols are ignored. This is modified from fairseq's `utils.make_positions`.
  117. """
  118. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  119. mask = input_ids.ne(padding_idx).int()
  120. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
  121. return incremental_indices.long() + padding_idx
  122. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  123. """
  124. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  125. """
  126. input_shape = inputs_embeds.size()[:-1]
  127. sequence_length = input_shape[1]
  128. position_ids = torch.arange(
  129. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  130. )
  131. return position_ids.unsqueeze(0).expand(input_shape)
  132. def forward(
  133. self,
  134. input_ids=None,
  135. bbox=None,
  136. token_type_ids=None,
  137. position_ids=None,
  138. inputs_embeds=None,
  139. ):
  140. if position_ids is None:
  141. if input_ids is not None:
  142. # Create the position ids from the input token ids. Any padded tokens remain padded.
  143. position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx).to(
  144. input_ids.device
  145. )
  146. else:
  147. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  148. if input_ids is not None:
  149. input_shape = input_ids.size()
  150. else:
  151. input_shape = inputs_embeds.size()[:-1]
  152. if token_type_ids is None:
  153. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  154. if inputs_embeds is None:
  155. inputs_embeds = self.word_embeddings(input_ids)
  156. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  157. embeddings = inputs_embeds + token_type_embeddings
  158. position_embeddings = self.position_embeddings(position_ids)
  159. embeddings += position_embeddings
  160. spatial_position_embeddings = self.calculate_spatial_position_embeddings(bbox)
  161. embeddings = embeddings + spatial_position_embeddings
  162. embeddings = self.LayerNorm(embeddings)
  163. embeddings = self.dropout(embeddings)
  164. return embeddings
  165. @auto_docstring
  166. class LayoutLMv3PreTrainedModel(PreTrainedModel):
  167. config: LayoutLMv3Config
  168. base_model_prefix = "layoutlmv3"
  169. def _init_weights(self, module):
  170. """Initialize the weights"""
  171. if isinstance(module, (nn.Linear, nn.Conv2d)):
  172. # Slightly different from the TF version which uses truncated_normal for initialization
  173. # cf https://github.com/pytorch/pytorch/pull/5617
  174. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  175. if module.bias is not None:
  176. module.bias.data.zero_()
  177. elif isinstance(module, nn.Embedding):
  178. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  179. if module.padding_idx is not None:
  180. module.weight.data[module.padding_idx].zero_()
  181. elif isinstance(module, nn.LayerNorm):
  182. module.bias.data.zero_()
  183. module.weight.data.fill_(1.0)
  184. elif isinstance(module, LayoutLMv3Model):
  185. if self.config.visual_embed:
  186. module.cls_token.data.zero_()
  187. module.pos_embed.data.zero_()
  188. class LayoutLMv3SelfAttention(nn.Module):
  189. def __init__(self, config):
  190. super().__init__()
  191. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  192. raise ValueError(
  193. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  194. f"heads ({config.num_attention_heads})"
  195. )
  196. self.num_attention_heads = config.num_attention_heads
  197. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  198. self.all_head_size = self.num_attention_heads * self.attention_head_size
  199. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  200. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  201. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  202. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  203. self.has_relative_attention_bias = config.has_relative_attention_bias
  204. self.has_spatial_attention_bias = config.has_spatial_attention_bias
  205. def cogview_attention(self, attention_scores, alpha=32):
  206. """
  207. https://huggingface.co/papers/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation
  208. (PB-Relax). A replacement of the original nn.Softmax(dim=-1)(attention_scores). Seems the new attention_probs
  209. will result in a slower speed and a little bias. Can use torch.allclose(standard_attention_probs,
  210. cogview_attention_probs, atol=1e-08) for comparison. The smaller atol (e.g., 1e-08), the better.
  211. """
  212. scaled_attention_scores = attention_scores / alpha
  213. max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)
  214. new_attention_scores = (scaled_attention_scores - max_value) * alpha
  215. return nn.Softmax(dim=-1)(new_attention_scores)
  216. def forward(
  217. self,
  218. hidden_states,
  219. attention_mask=None,
  220. head_mask=None,
  221. output_attentions=False,
  222. rel_pos=None,
  223. rel_2d_pos=None,
  224. ):
  225. batch_size, seq_length, _ = hidden_states.shape
  226. query_layer = (
  227. self.query(hidden_states)
  228. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  229. .transpose(1, 2)
  230. )
  231. key_layer = (
  232. self.key(hidden_states)
  233. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  234. .transpose(1, 2)
  235. )
  236. value_layer = (
  237. self.value(hidden_states)
  238. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  239. .transpose(1, 2)
  240. )
  241. # Take the dot product between "query" and "key" to get the raw attention scores.
  242. # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow.
  243. # Changing the computational order into QT(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290)
  244. attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
  245. if self.has_relative_attention_bias and self.has_spatial_attention_bias:
  246. attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)
  247. elif self.has_relative_attention_bias:
  248. attention_scores += rel_pos / math.sqrt(self.attention_head_size)
  249. if attention_mask is not None:
  250. # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
  251. attention_scores = attention_scores + attention_mask
  252. # Normalize the attention scores to probabilities.
  253. # Use the trick of the CogView paper to stabilize training
  254. attention_probs = self.cogview_attention(attention_scores)
  255. # This is actually dropping out entire tokens to attend to, which might
  256. # seem a bit unusual, but is taken from the original Transformer paper.
  257. attention_probs = self.dropout(attention_probs)
  258. # Mask heads if we want to
  259. if head_mask is not None:
  260. attention_probs = attention_probs * head_mask
  261. context_layer = torch.matmul(attention_probs, value_layer)
  262. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  263. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  264. context_layer = context_layer.view(*new_context_layer_shape)
  265. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  266. return outputs
  267. # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput
  268. class LayoutLMv3SelfOutput(nn.Module):
  269. def __init__(self, config):
  270. super().__init__()
  271. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  272. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  273. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  274. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  275. hidden_states = self.dense(hidden_states)
  276. hidden_states = self.dropout(hidden_states)
  277. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  278. return hidden_states
  279. # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
  280. class LayoutLMv3Attention(nn.Module):
  281. def __init__(self, config):
  282. super().__init__()
  283. self.self = LayoutLMv3SelfAttention(config)
  284. self.output = LayoutLMv3SelfOutput(config)
  285. def forward(
  286. self,
  287. hidden_states,
  288. attention_mask=None,
  289. head_mask=None,
  290. output_attentions=False,
  291. rel_pos=None,
  292. rel_2d_pos=None,
  293. ):
  294. self_outputs = self.self(
  295. hidden_states,
  296. attention_mask,
  297. head_mask,
  298. output_attentions,
  299. rel_pos=rel_pos,
  300. rel_2d_pos=rel_2d_pos,
  301. )
  302. attention_output = self.output(self_outputs[0], hidden_states)
  303. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  304. return outputs
  305. # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
  306. class LayoutLMv3Layer(GradientCheckpointingLayer):
  307. def __init__(self, config):
  308. super().__init__()
  309. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  310. self.seq_len_dim = 1
  311. self.attention = LayoutLMv3Attention(config)
  312. self.intermediate = LayoutLMv3Intermediate(config)
  313. self.output = LayoutLMv3Output(config)
  314. def forward(
  315. self,
  316. hidden_states,
  317. attention_mask=None,
  318. head_mask=None,
  319. output_attentions=False,
  320. rel_pos=None,
  321. rel_2d_pos=None,
  322. ):
  323. self_attention_outputs = self.attention(
  324. hidden_states,
  325. attention_mask,
  326. head_mask,
  327. output_attentions=output_attentions,
  328. rel_pos=rel_pos,
  329. rel_2d_pos=rel_2d_pos,
  330. )
  331. attention_output = self_attention_outputs[0]
  332. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  333. layer_output = apply_chunking_to_forward(
  334. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  335. )
  336. outputs = (layer_output,) + outputs
  337. return outputs
  338. def feed_forward_chunk(self, attention_output):
  339. intermediate_output = self.intermediate(attention_output)
  340. layer_output = self.output(intermediate_output, attention_output)
  341. return layer_output
  342. class LayoutLMv3Encoder(nn.Module):
  343. def __init__(self, config):
  344. super().__init__()
  345. self.config = config
  346. self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)])
  347. self.gradient_checkpointing = False
  348. self.has_relative_attention_bias = config.has_relative_attention_bias
  349. self.has_spatial_attention_bias = config.has_spatial_attention_bias
  350. if self.has_relative_attention_bias:
  351. self.rel_pos_bins = config.rel_pos_bins
  352. self.max_rel_pos = config.max_rel_pos
  353. self.rel_pos_bias = nn.Linear(self.rel_pos_bins, config.num_attention_heads, bias=False)
  354. if self.has_spatial_attention_bias:
  355. self.max_rel_2d_pos = config.max_rel_2d_pos
  356. self.rel_2d_pos_bins = config.rel_2d_pos_bins
  357. self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
  358. self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
  359. def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
  360. ret = 0
  361. if bidirectional:
  362. num_buckets //= 2
  363. ret += (relative_position > 0).long() * num_buckets
  364. n = torch.abs(relative_position)
  365. else:
  366. n = torch.max(-relative_position, torch.zeros_like(relative_position))
  367. # now n is in the range [0, inf)
  368. # half of the buckets are for exact increments in positions
  369. max_exact = num_buckets // 2
  370. is_small = n < max_exact
  371. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  372. val_if_large = max_exact + (
  373. torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
  374. ).to(torch.long)
  375. val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
  376. ret += torch.where(is_small, n, val_if_large)
  377. return ret
  378. def _cal_1d_pos_emb(self, position_ids):
  379. rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
  380. rel_pos = self.relative_position_bucket(
  381. rel_pos_mat,
  382. num_buckets=self.rel_pos_bins,
  383. max_distance=self.max_rel_pos,
  384. )
  385. # Since this is a simple indexing operation that is independent of the input,
  386. # no need to track gradients for this operation
  387. #
  388. # Without this no_grad context, training speed slows down significantly
  389. with torch.no_grad():
  390. rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
  391. rel_pos = rel_pos.contiguous()
  392. return rel_pos
  393. def _cal_2d_pos_emb(self, bbox):
  394. position_coord_x = bbox[:, :, 0]
  395. position_coord_y = bbox[:, :, 3]
  396. rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
  397. rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)
  398. rel_pos_x = self.relative_position_bucket(
  399. rel_pos_x_2d_mat,
  400. num_buckets=self.rel_2d_pos_bins,
  401. max_distance=self.max_rel_2d_pos,
  402. )
  403. rel_pos_y = self.relative_position_bucket(
  404. rel_pos_y_2d_mat,
  405. num_buckets=self.rel_2d_pos_bins,
  406. max_distance=self.max_rel_2d_pos,
  407. )
  408. # Since this is a simple indexing operation that is independent of the input,
  409. # no need to track gradients for this operation
  410. #
  411. # Without this no_grad context, training speed slows down significantly
  412. with torch.no_grad():
  413. rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
  414. rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
  415. rel_pos_x = rel_pos_x.contiguous()
  416. rel_pos_y = rel_pos_y.contiguous()
  417. rel_2d_pos = rel_pos_x + rel_pos_y
  418. return rel_2d_pos
  419. def forward(
  420. self,
  421. hidden_states,
  422. bbox=None,
  423. attention_mask=None,
  424. head_mask=None,
  425. output_attentions=False,
  426. output_hidden_states=False,
  427. return_dict=True,
  428. position_ids=None,
  429. patch_height=None,
  430. patch_width=None,
  431. ):
  432. all_hidden_states = () if output_hidden_states else None
  433. all_self_attentions = () if output_attentions else None
  434. rel_pos = self._cal_1d_pos_emb(position_ids) if self.has_relative_attention_bias else None
  435. rel_2d_pos = self._cal_2d_pos_emb(bbox) if self.has_spatial_attention_bias else None
  436. for i, layer_module in enumerate(self.layer):
  437. if output_hidden_states:
  438. all_hidden_states = all_hidden_states + (hidden_states,)
  439. layer_head_mask = head_mask[i] if head_mask is not None else None
  440. layer_outputs = layer_module(
  441. hidden_states,
  442. attention_mask,
  443. layer_head_mask,
  444. output_attentions,
  445. rel_pos=rel_pos,
  446. rel_2d_pos=rel_2d_pos,
  447. )
  448. hidden_states = layer_outputs[0]
  449. if output_attentions:
  450. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  451. if output_hidden_states:
  452. all_hidden_states = all_hidden_states + (hidden_states,)
  453. if not return_dict:
  454. return tuple(
  455. v
  456. for v in [
  457. hidden_states,
  458. all_hidden_states,
  459. all_self_attentions,
  460. ]
  461. if v is not None
  462. )
  463. return BaseModelOutput(
  464. last_hidden_state=hidden_states,
  465. hidden_states=all_hidden_states,
  466. attentions=all_self_attentions,
  467. )
  468. # Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate
  469. class LayoutLMv3Intermediate(nn.Module):
  470. def __init__(self, config):
  471. super().__init__()
  472. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  473. if isinstance(config.hidden_act, str):
  474. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  475. else:
  476. self.intermediate_act_fn = config.hidden_act
  477. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  478. hidden_states = self.dense(hidden_states)
  479. hidden_states = self.intermediate_act_fn(hidden_states)
  480. return hidden_states
  481. # Copied from transformers.models.roberta.modeling_roberta.RobertaOutput
  482. class LayoutLMv3Output(nn.Module):
  483. def __init__(self, config):
  484. super().__init__()
  485. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  486. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  487. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  488. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  489. hidden_states = self.dense(hidden_states)
  490. hidden_states = self.dropout(hidden_states)
  491. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  492. return hidden_states
  493. @auto_docstring
  494. class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
  495. def __init__(self, config):
  496. super().__init__(config)
  497. self.config = config
  498. if config.text_embed:
  499. self.embeddings = LayoutLMv3TextEmbeddings(config)
  500. if config.visual_embed:
  501. # use the default pre-training parameters for fine-tuning (e.g., input_size)
  502. # when the input_size is larger in fine-tuning, we will interpolate the position embeddings in forward
  503. self.patch_embed = LayoutLMv3PatchEmbeddings(config)
  504. size = int(config.input_size / config.patch_size)
  505. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  506. self.pos_embed = nn.Parameter(torch.zeros(1, size * size + 1, config.hidden_size))
  507. self.pos_drop = nn.Dropout(p=0.0)
  508. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  509. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  510. if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
  511. self.init_visual_bbox(image_size=(size, size))
  512. self.norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
  513. self.encoder = LayoutLMv3Encoder(config)
  514. self.init_weights()
  515. def get_input_embeddings(self):
  516. return self.embeddings.word_embeddings
  517. def set_input_embeddings(self, value):
  518. self.embeddings.word_embeddings = value
  519. def _prune_heads(self, heads_to_prune):
  520. """
  521. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  522. class PreTrainedModel
  523. """
  524. for layer, heads in heads_to_prune.items():
  525. self.encoder.layer[layer].attention.prune_heads(heads)
  526. def init_visual_bbox(self, image_size=(14, 14), max_len=1000):
  527. """
  528. Create the bounding boxes for the visual (patch) tokens.
  529. """
  530. visual_bbox_x = torch.div(
  531. torch.arange(0, max_len * (image_size[1] + 1), max_len), image_size[1], rounding_mode="trunc"
  532. )
  533. visual_bbox_y = torch.div(
  534. torch.arange(0, max_len * (image_size[0] + 1), max_len), image_size[0], rounding_mode="trunc"
  535. )
  536. visual_bbox = torch.stack(
  537. [
  538. visual_bbox_x[:-1].repeat(image_size[0], 1),
  539. visual_bbox_y[:-1].repeat(image_size[1], 1).transpose(0, 1),
  540. visual_bbox_x[1:].repeat(image_size[0], 1),
  541. visual_bbox_y[1:].repeat(image_size[1], 1).transpose(0, 1),
  542. ],
  543. dim=-1,
  544. ).view(-1, 4)
  545. cls_token_box = torch.tensor([[0 + 1, 0 + 1, max_len - 1, max_len - 1]])
  546. self.visual_bbox = torch.cat([cls_token_box, visual_bbox], dim=0)
  547. def calculate_visual_bbox(self, device, dtype, batch_size):
  548. visual_bbox = self.visual_bbox.repeat(batch_size, 1, 1)
  549. visual_bbox = visual_bbox.to(device).type(dtype)
  550. return visual_bbox
  551. def forward_image(self, pixel_values):
  552. embeddings = self.patch_embed(pixel_values)
  553. # add [CLS] token
  554. batch_size, seq_len, _ = embeddings.size()
  555. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  556. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  557. # add position embeddings
  558. if self.pos_embed is not None:
  559. embeddings = embeddings + self.pos_embed
  560. embeddings = self.pos_drop(embeddings)
  561. embeddings = self.norm(embeddings)
  562. return embeddings
  563. @auto_docstring
  564. def forward(
  565. self,
  566. input_ids: Optional[torch.LongTensor] = None,
  567. bbox: Optional[torch.LongTensor] = None,
  568. attention_mask: Optional[torch.FloatTensor] = None,
  569. token_type_ids: Optional[torch.LongTensor] = None,
  570. position_ids: Optional[torch.LongTensor] = None,
  571. head_mask: Optional[torch.FloatTensor] = None,
  572. inputs_embeds: Optional[torch.FloatTensor] = None,
  573. pixel_values: Optional[torch.FloatTensor] = None,
  574. output_attentions: Optional[bool] = None,
  575. output_hidden_states: Optional[bool] = None,
  576. return_dict: Optional[bool] = None,
  577. ) -> Union[tuple, BaseModelOutput]:
  578. r"""
  579. input_ids (`torch.LongTensor` of shape `(batch_size, token_sequence_length)`):
  580. Indices of input sequence tokens in the vocabulary.
  581. Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
  582. token. See `pixel_values` for `patch_sequence_length`.
  583. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  584. [`PreTrainedTokenizer.__call__`] for details.
  585. [What are input IDs?](../glossary#input-ids)
  586. bbox (`torch.LongTensor` of shape `(batch_size, token_sequence_length, 4)`, *optional*):
  587. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  588. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  589. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  590. y1) represents the position of the lower right corner.
  591. Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
  592. token. See `pixel_values` for `patch_sequence_length`.
  593. token_type_ids (`torch.LongTensor` of shape `(batch_size, token_sequence_length)`, *optional*):
  594. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  595. 1]`:
  596. - 0 corresponds to a *sentence A* token,
  597. - 1 corresponds to a *sentence B* token.
  598. Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
  599. token. See `pixel_values` for `patch_sequence_length`.
  600. [What are token type IDs?](../glossary#token-type-ids)
  601. position_ids (`torch.LongTensor` of shape `(batch_size, token_sequence_length)`, *optional*):
  602. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  603. config.max_position_embeddings - 1]`.
  604. Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
  605. token. See `pixel_values` for `patch_sequence_length`.
  606. [What are position IDs?](../glossary#position-ids)
  607. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, token_sequence_length, hidden_size)`, *optional*):
  608. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  609. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  610. model's internal embedding lookup matrix.
  611. Examples:
  612. ```python
  613. >>> from transformers import AutoProcessor, AutoModel
  614. >>> from datasets import load_dataset
  615. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
  616. >>> model = AutoModel.from_pretrained("microsoft/layoutlmv3-base")
  617. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  618. >>> example = dataset[0]
  619. >>> image = example["image"]
  620. >>> words = example["tokens"]
  621. >>> boxes = example["bboxes"]
  622. >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
  623. >>> outputs = model(**encoding)
  624. >>> last_hidden_states = outputs.last_hidden_state
  625. ```"""
  626. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  627. output_hidden_states = (
  628. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  629. )
  630. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  631. if input_ids is not None:
  632. input_shape = input_ids.size()
  633. batch_size, seq_length = input_shape
  634. device = input_ids.device
  635. elif inputs_embeds is not None:
  636. input_shape = inputs_embeds.size()[:-1]
  637. batch_size, seq_length = input_shape
  638. device = inputs_embeds.device
  639. elif pixel_values is not None:
  640. batch_size = len(pixel_values)
  641. device = pixel_values.device
  642. else:
  643. raise ValueError("You have to specify either input_ids or inputs_embeds or pixel_values")
  644. if input_ids is not None or inputs_embeds is not None:
  645. if attention_mask is None:
  646. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  647. if token_type_ids is None:
  648. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  649. if bbox is None:
  650. bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
  651. embedding_output = self.embeddings(
  652. input_ids=input_ids,
  653. bbox=bbox,
  654. position_ids=position_ids,
  655. token_type_ids=token_type_ids,
  656. inputs_embeds=inputs_embeds,
  657. )
  658. final_bbox = final_position_ids = None
  659. patch_height = patch_width = None
  660. if pixel_values is not None:
  661. patch_height, patch_width = (
  662. torch_int(pixel_values.shape[2] / self.config.patch_size),
  663. torch_int(pixel_values.shape[3] / self.config.patch_size),
  664. )
  665. visual_embeddings = self.forward_image(pixel_values)
  666. visual_attention_mask = torch.ones(
  667. (batch_size, visual_embeddings.shape[1]), dtype=torch.long, device=device
  668. )
  669. if attention_mask is not None:
  670. attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1)
  671. else:
  672. attention_mask = visual_attention_mask
  673. if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
  674. if self.config.has_spatial_attention_bias:
  675. visual_bbox = self.calculate_visual_bbox(device, dtype=torch.long, batch_size=batch_size)
  676. if bbox is not None:
  677. final_bbox = torch.cat([bbox, visual_bbox], dim=1)
  678. else:
  679. final_bbox = visual_bbox
  680. visual_position_ids = torch.arange(
  681. 0, visual_embeddings.shape[1], dtype=torch.long, device=device
  682. ).repeat(batch_size, 1)
  683. if input_ids is not None or inputs_embeds is not None:
  684. position_ids = torch.arange(0, input_shape[1], device=device).unsqueeze(0)
  685. position_ids = position_ids.expand(input_shape)
  686. final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1)
  687. else:
  688. final_position_ids = visual_position_ids
  689. if input_ids is not None or inputs_embeds is not None:
  690. embedding_output = torch.cat([embedding_output, visual_embeddings], dim=1)
  691. else:
  692. embedding_output = visual_embeddings
  693. embedding_output = self.LayerNorm(embedding_output)
  694. embedding_output = self.dropout(embedding_output)
  695. elif self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
  696. if self.config.has_spatial_attention_bias:
  697. final_bbox = bbox
  698. if self.config.has_relative_attention_bias:
  699. position_ids = self.embeddings.position_ids[:, : input_shape[1]]
  700. position_ids = position_ids.expand_as(input_ids)
  701. final_position_ids = position_ids
  702. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
  703. attention_mask, None, device, dtype=embedding_output.dtype
  704. )
  705. # Prepare head mask if needed
  706. # 1.0 in head_mask indicate we keep the head
  707. # attention_probs has shape bsz x n_heads x N x N
  708. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  709. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  710. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  711. encoder_outputs = self.encoder(
  712. embedding_output,
  713. bbox=final_bbox,
  714. position_ids=final_position_ids,
  715. attention_mask=extended_attention_mask,
  716. head_mask=head_mask,
  717. output_attentions=output_attentions,
  718. output_hidden_states=output_hidden_states,
  719. return_dict=return_dict,
  720. patch_height=patch_height,
  721. patch_width=patch_width,
  722. )
  723. sequence_output = encoder_outputs[0]
  724. if not return_dict:
  725. return (sequence_output,) + encoder_outputs[1:]
  726. return BaseModelOutput(
  727. last_hidden_state=sequence_output,
  728. hidden_states=encoder_outputs.hidden_states,
  729. attentions=encoder_outputs.attentions,
  730. )
  731. class LayoutLMv3ClassificationHead(nn.Module):
  732. """
  733. Head for sentence-level classification tasks. Reference: RobertaClassificationHead
  734. """
  735. def __init__(self, config, pool_feature=False):
  736. super().__init__()
  737. self.pool_feature = pool_feature
  738. if pool_feature:
  739. self.dense = nn.Linear(config.hidden_size * 3, config.hidden_size)
  740. else:
  741. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  742. classifier_dropout = (
  743. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  744. )
  745. self.dropout = nn.Dropout(classifier_dropout)
  746. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  747. def forward(self, x):
  748. x = self.dropout(x)
  749. x = self.dense(x)
  750. x = torch.tanh(x)
  751. x = self.dropout(x)
  752. x = self.out_proj(x)
  753. return x
  754. @auto_docstring(
  755. custom_intro="""
  756. LayoutLMv3 Model with a token classification head on top (a linear layer on top of the final hidden states) e.g.
  757. for sequence labeling (information extraction) tasks such as [FUNSD](https://guillaumejaume.github.io/FUNSD/),
  758. [SROIE](https://rrc.cvc.uab.es/?ch=13), [CORD](https://github.com/clovaai/cord) and
  759. [Kleister-NDA](https://github.com/applicaai/kleister-nda).
  760. """
  761. )
  762. class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):
  763. def __init__(self, config):
  764. super().__init__(config)
  765. self.num_labels = config.num_labels
  766. self.layoutlmv3 = LayoutLMv3Model(config)
  767. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  768. if config.num_labels < 10:
  769. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  770. else:
  771. self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
  772. self.init_weights()
  773. @auto_docstring
  774. def forward(
  775. self,
  776. input_ids: Optional[torch.LongTensor] = None,
  777. bbox: Optional[torch.LongTensor] = None,
  778. attention_mask: Optional[torch.FloatTensor] = None,
  779. token_type_ids: Optional[torch.LongTensor] = None,
  780. position_ids: Optional[torch.LongTensor] = None,
  781. head_mask: Optional[torch.FloatTensor] = None,
  782. inputs_embeds: Optional[torch.FloatTensor] = None,
  783. labels: Optional[torch.LongTensor] = None,
  784. output_attentions: Optional[bool] = None,
  785. output_hidden_states: Optional[bool] = None,
  786. return_dict: Optional[bool] = None,
  787. pixel_values: Optional[torch.LongTensor] = None,
  788. ) -> Union[tuple, TokenClassifierOutput]:
  789. r"""
  790. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  791. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  792. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  793. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  794. y1) represents the position of the lower right corner.
  795. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  796. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  797. Examples:
  798. ```python
  799. >>> from transformers import AutoProcessor, AutoModelForTokenClassification
  800. >>> from datasets import load_dataset
  801. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
  802. >>> model = AutoModelForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=7)
  803. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  804. >>> example = dataset[0]
  805. >>> image = example["image"]
  806. >>> words = example["tokens"]
  807. >>> boxes = example["bboxes"]
  808. >>> word_labels = example["ner_tags"]
  809. >>> encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
  810. >>> outputs = model(**encoding)
  811. >>> loss = outputs.loss
  812. >>> logits = outputs.logits
  813. ```"""
  814. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  815. outputs = self.layoutlmv3(
  816. input_ids,
  817. bbox=bbox,
  818. attention_mask=attention_mask,
  819. token_type_ids=token_type_ids,
  820. position_ids=position_ids,
  821. head_mask=head_mask,
  822. inputs_embeds=inputs_embeds,
  823. output_attentions=output_attentions,
  824. output_hidden_states=output_hidden_states,
  825. return_dict=return_dict,
  826. pixel_values=pixel_values,
  827. )
  828. if input_ids is not None:
  829. input_shape = input_ids.size()
  830. else:
  831. input_shape = inputs_embeds.size()[:-1]
  832. seq_length = input_shape[1]
  833. # only take the text part of the output representations
  834. sequence_output = outputs[0][:, :seq_length]
  835. sequence_output = self.dropout(sequence_output)
  836. logits = self.classifier(sequence_output)
  837. loss = None
  838. if labels is not None:
  839. loss_fct = CrossEntropyLoss()
  840. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  841. if not return_dict:
  842. output = (logits,) + outputs[1:]
  843. return ((loss,) + output) if loss is not None else output
  844. return TokenClassifierOutput(
  845. loss=loss,
  846. logits=logits,
  847. hidden_states=outputs.hidden_states,
  848. attentions=outputs.attentions,
  849. )
  850. @auto_docstring
  851. class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):
  852. def __init__(self, config):
  853. super().__init__(config)
  854. self.num_labels = config.num_labels
  855. self.layoutlmv3 = LayoutLMv3Model(config)
  856. self.qa_outputs = LayoutLMv3ClassificationHead(config, pool_feature=False)
  857. self.init_weights()
  858. @auto_docstring
  859. def forward(
  860. self,
  861. input_ids: Optional[torch.LongTensor] = None,
  862. attention_mask: Optional[torch.FloatTensor] = None,
  863. token_type_ids: Optional[torch.LongTensor] = None,
  864. position_ids: Optional[torch.LongTensor] = None,
  865. head_mask: Optional[torch.FloatTensor] = None,
  866. inputs_embeds: Optional[torch.FloatTensor] = None,
  867. start_positions: Optional[torch.LongTensor] = None,
  868. end_positions: Optional[torch.LongTensor] = None,
  869. output_attentions: Optional[bool] = None,
  870. output_hidden_states: Optional[bool] = None,
  871. return_dict: Optional[bool] = None,
  872. bbox: Optional[torch.LongTensor] = None,
  873. pixel_values: Optional[torch.LongTensor] = None,
  874. ) -> Union[tuple, QuestionAnsweringModelOutput]:
  875. r"""
  876. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  877. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  878. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  879. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  880. y1) represents the position of the lower right corner.
  881. Examples:
  882. ```python
  883. >>> from transformers import AutoProcessor, AutoModelForQuestionAnswering
  884. >>> from datasets import load_dataset
  885. >>> import torch
  886. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
  887. >>> model = AutoModelForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")
  888. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  889. >>> example = dataset[0]
  890. >>> image = example["image"]
  891. >>> question = "what's his name?"
  892. >>> words = example["tokens"]
  893. >>> boxes = example["bboxes"]
  894. >>> encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")
  895. >>> start_positions = torch.tensor([1])
  896. >>> end_positions = torch.tensor([3])
  897. >>> outputs = model(**encoding, start_positions=start_positions, end_positions=end_positions)
  898. >>> loss = outputs.loss
  899. >>> start_scores = outputs.start_logits
  900. >>> end_scores = outputs.end_logits
  901. ```"""
  902. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  903. outputs = self.layoutlmv3(
  904. input_ids,
  905. attention_mask=attention_mask,
  906. token_type_ids=token_type_ids,
  907. position_ids=position_ids,
  908. head_mask=head_mask,
  909. inputs_embeds=inputs_embeds,
  910. output_attentions=output_attentions,
  911. output_hidden_states=output_hidden_states,
  912. return_dict=return_dict,
  913. bbox=bbox,
  914. pixel_values=pixel_values,
  915. )
  916. sequence_output = outputs[0]
  917. logits = self.qa_outputs(sequence_output)
  918. start_logits, end_logits = logits.split(1, dim=-1)
  919. start_logits = start_logits.squeeze(-1).contiguous()
  920. end_logits = end_logits.squeeze(-1).contiguous()
  921. total_loss = None
  922. if start_positions is not None and end_positions is not None:
  923. # If we are on multi-GPU, split add a dimension
  924. if len(start_positions.size()) > 1:
  925. start_positions = start_positions.squeeze(-1)
  926. if len(end_positions.size()) > 1:
  927. end_positions = end_positions.squeeze(-1)
  928. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  929. ignored_index = start_logits.size(1)
  930. start_positions = start_positions.clamp(0, ignored_index)
  931. end_positions = end_positions.clamp(0, ignored_index)
  932. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  933. start_loss = loss_fct(start_logits, start_positions)
  934. end_loss = loss_fct(end_logits, end_positions)
  935. total_loss = (start_loss + end_loss) / 2
  936. if not return_dict:
  937. output = (start_logits, end_logits) + outputs[1:]
  938. return ((total_loss,) + output) if total_loss is not None else output
  939. return QuestionAnsweringModelOutput(
  940. loss=total_loss,
  941. start_logits=start_logits,
  942. end_logits=end_logits,
  943. hidden_states=outputs.hidden_states,
  944. attentions=outputs.attentions,
  945. )
  946. @auto_docstring(
  947. custom_intro="""
  948. LayoutLMv3 Model with a sequence classification head on top (a linear layer on top of the final hidden state of the
  949. [CLS] token) e.g. for document image classification tasks such as the
  950. [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset.
  951. """
  952. )
  953. class LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel):
  954. def __init__(self, config):
  955. super().__init__(config)
  956. self.num_labels = config.num_labels
  957. self.config = config
  958. self.layoutlmv3 = LayoutLMv3Model(config)
  959. self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
  960. self.init_weights()
  961. @auto_docstring
  962. def forward(
  963. self,
  964. input_ids: Optional[torch.LongTensor] = None,
  965. attention_mask: Optional[torch.FloatTensor] = None,
  966. token_type_ids: Optional[torch.LongTensor] = None,
  967. position_ids: Optional[torch.LongTensor] = None,
  968. head_mask: Optional[torch.FloatTensor] = None,
  969. inputs_embeds: Optional[torch.FloatTensor] = None,
  970. labels: Optional[torch.LongTensor] = None,
  971. output_attentions: Optional[bool] = None,
  972. output_hidden_states: Optional[bool] = None,
  973. return_dict: Optional[bool] = None,
  974. bbox: Optional[torch.LongTensor] = None,
  975. pixel_values: Optional[torch.LongTensor] = None,
  976. ) -> Union[tuple, SequenceClassifierOutput]:
  977. r"""
  978. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  979. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  980. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  981. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  982. y1) represents the position of the lower right corner.
  983. Examples:
  984. ```python
  985. >>> from transformers import AutoProcessor, AutoModelForSequenceClassification
  986. >>> from datasets import load_dataset
  987. >>> import torch
  988. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
  989. >>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
  990. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  991. >>> example = dataset[0]
  992. >>> image = example["image"]
  993. >>> words = example["tokens"]
  994. >>> boxes = example["bboxes"]
  995. >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
  996. >>> sequence_label = torch.tensor([1])
  997. >>> outputs = model(**encoding, labels=sequence_label)
  998. >>> loss = outputs.loss
  999. >>> logits = outputs.logits
  1000. ```"""
  1001. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1002. outputs = self.layoutlmv3(
  1003. input_ids,
  1004. attention_mask=attention_mask,
  1005. token_type_ids=token_type_ids,
  1006. position_ids=position_ids,
  1007. head_mask=head_mask,
  1008. inputs_embeds=inputs_embeds,
  1009. output_attentions=output_attentions,
  1010. output_hidden_states=output_hidden_states,
  1011. return_dict=return_dict,
  1012. bbox=bbox,
  1013. pixel_values=pixel_values,
  1014. )
  1015. sequence_output = outputs[0][:, 0, :]
  1016. logits = self.classifier(sequence_output)
  1017. loss = None
  1018. if labels is not None:
  1019. if self.config.problem_type is None:
  1020. if self.num_labels == 1:
  1021. self.config.problem_type = "regression"
  1022. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1023. self.config.problem_type = "single_label_classification"
  1024. else:
  1025. self.config.problem_type = "multi_label_classification"
  1026. if self.config.problem_type == "regression":
  1027. loss_fct = MSELoss()
  1028. if self.num_labels == 1:
  1029. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1030. else:
  1031. loss = loss_fct(logits, labels)
  1032. elif self.config.problem_type == "single_label_classification":
  1033. loss_fct = CrossEntropyLoss()
  1034. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1035. elif self.config.problem_type == "multi_label_classification":
  1036. loss_fct = BCEWithLogitsLoss()
  1037. loss = loss_fct(logits, labels)
  1038. if not return_dict:
  1039. output = (logits,) + outputs[1:]
  1040. return ((loss,) + output) if loss is not None else output
  1041. return SequenceClassifierOutput(
  1042. loss=loss,
  1043. logits=logits,
  1044. hidden_states=outputs.hidden_states,
  1045. attentions=outputs.attentions,
  1046. )
  1047. __all__ = [
  1048. "LayoutLMv3ForQuestionAnswering",
  1049. "LayoutLMv3ForSequenceClassification",
  1050. "LayoutLMv3ForTokenClassification",
  1051. "LayoutLMv3Model",
  1052. "LayoutLMv3PreTrainedModel",
  1053. ]