modeling_layoutlmv2.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394
  1. # coding=utf-8
  2. # Copyright 2021 Microsoft Research The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch LayoutLMv2 model."""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. BaseModelOutputWithPooling,
  26. QuestionAnsweringModelOutput,
  27. SequenceClassifierOutput,
  28. TokenClassifierOutput,
  29. )
  30. from ...modeling_utils import PreTrainedModel
  31. from ...pytorch_utils import apply_chunking_to_forward
  32. from ...utils import auto_docstring, is_detectron2_available, logging, requires_backends
  33. from .configuration_layoutlmv2 import LayoutLMv2Config
# soft dependency: detectron2 is imported only when it is installed; in this file it is used by the visual
# backbone (META_ARCH_REGISTRY, detectron2.modeling.backbone.FPN) and by my_convert_sync_batchnorm.
if is_detectron2_available():
    import detectron2
    from detectron2.modeling import META_ARCH_REGISTRY

    # This is needed as otherwise their overload will break sequential loading by overwriting buffer over and over. See
    # https://github.com/facebookresearch/detectron2/blob/9604f5995cc628619f0e4fd913453b4d7d61db3f/detectron2/layers/batch_norm.py#L83-L86
    # The assignment patches the class attribute, so it affects every FrozenBatchNorm2d in the process: they all
    # fall back to the default torch.nn.Module state-dict loading.
    detectron2.layers.batch_norm.FrozenBatchNorm2d._load_from_state_dict = torch.nn.Module._load_from_state_dict

logger = logging.get_logger(__name__)
  42. class LayoutLMv2Embeddings(nn.Module):
  43. """Construct the embeddings from word, position and token_type embeddings."""
  44. def __init__(self, config):
  45. super().__init__()
  46. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  47. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  48. self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
  49. self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
  50. self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
  51. self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
  52. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  53. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  54. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  55. self.register_buffer(
  56. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  57. )
  58. def _calc_spatial_position_embeddings(self, bbox):
  59. try:
  60. left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
  61. upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
  62. right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
  63. lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
  64. except IndexError as e:
  65. raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e
  66. h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
  67. w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
  68. spatial_position_embeddings = torch.cat(
  69. [
  70. left_position_embeddings,
  71. upper_position_embeddings,
  72. right_position_embeddings,
  73. lower_position_embeddings,
  74. h_position_embeddings,
  75. w_position_embeddings,
  76. ],
  77. dim=-1,
  78. )
  79. return spatial_position_embeddings
  80. class LayoutLMv2SelfAttention(nn.Module):
  81. def __init__(self, config):
  82. super().__init__()
  83. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  84. raise ValueError(
  85. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  86. f"heads ({config.num_attention_heads})"
  87. )
  88. self.fast_qkv = config.fast_qkv
  89. self.num_attention_heads = config.num_attention_heads
  90. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  91. self.all_head_size = self.num_attention_heads * self.attention_head_size
  92. self.has_relative_attention_bias = config.has_relative_attention_bias
  93. self.has_spatial_attention_bias = config.has_spatial_attention_bias
  94. if config.fast_qkv:
  95. self.qkv_linear = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=False)
  96. self.q_bias = nn.Parameter(torch.zeros(1, 1, self.all_head_size))
  97. self.v_bias = nn.Parameter(torch.zeros(1, 1, self.all_head_size))
  98. else:
  99. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  100. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  101. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  102. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  103. def compute_qkv(self, hidden_states):
  104. if self.fast_qkv:
  105. qkv = self.qkv_linear(hidden_states)
  106. q, k, v = torch.chunk(qkv, 3, dim=-1)
  107. if q.ndimension() == self.q_bias.ndimension():
  108. q = q + self.q_bias
  109. v = v + self.v_bias
  110. else:
  111. _sz = (1,) * (q.ndimension() - 1) + (-1,)
  112. q = q + self.q_bias.view(*_sz)
  113. v = v + self.v_bias.view(*_sz)
  114. else:
  115. q = self.query(hidden_states)
  116. k = self.key(hidden_states)
  117. v = self.value(hidden_states)
  118. return q, k, v
  119. def forward(
  120. self,
  121. hidden_states,
  122. attention_mask=None,
  123. head_mask=None,
  124. output_attentions=False,
  125. rel_pos=None,
  126. rel_2d_pos=None,
  127. ):
  128. batch_size, seq_length, _ = hidden_states.shape
  129. query, key, value = self.compute_qkv(hidden_states)
  130. # (B, L, H*D) -> (B, H, L, D)
  131. query_layer = query.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  132. key_layer = key.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  133. value_layer = value.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  134. query_layer = query_layer / math.sqrt(self.attention_head_size)
  135. # [BSZ, NAT, L, L]
  136. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  137. if self.has_relative_attention_bias:
  138. attention_scores += rel_pos
  139. if self.has_spatial_attention_bias:
  140. attention_scores += rel_2d_pos
  141. attention_scores = attention_scores.float().masked_fill_(
  142. attention_mask.to(torch.bool), torch.finfo(attention_scores.dtype).min
  143. )
  144. attention_probs = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).type_as(value_layer)
  145. # This is actually dropping out entire tokens to attend to, which might
  146. # seem a bit unusual, but is taken from the original Transformer paper.
  147. attention_probs = self.dropout(attention_probs)
  148. # Mask heads if we want to
  149. if head_mask is not None:
  150. attention_probs = attention_probs * head_mask
  151. context_layer = torch.matmul(attention_probs, value_layer)
  152. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  153. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  154. context_layer = context_layer.view(*new_context_layer_shape)
  155. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  156. return outputs
  157. class LayoutLMv2Attention(nn.Module):
  158. def __init__(self, config):
  159. super().__init__()
  160. self.self = LayoutLMv2SelfAttention(config)
  161. self.output = LayoutLMv2SelfOutput(config)
  162. def forward(
  163. self,
  164. hidden_states,
  165. attention_mask=None,
  166. head_mask=None,
  167. output_attentions=False,
  168. rel_pos=None,
  169. rel_2d_pos=None,
  170. ):
  171. self_outputs = self.self(
  172. hidden_states,
  173. attention_mask,
  174. head_mask,
  175. output_attentions,
  176. rel_pos=rel_pos,
  177. rel_2d_pos=rel_2d_pos,
  178. )
  179. attention_output = self.output(self_outputs[0], hidden_states)
  180. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  181. return outputs
  182. class LayoutLMv2SelfOutput(nn.Module):
  183. def __init__(self, config):
  184. super().__init__()
  185. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  186. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  187. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  188. def forward(self, hidden_states, input_tensor):
  189. hidden_states = self.dense(hidden_states)
  190. hidden_states = self.dropout(hidden_states)
  191. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  192. return hidden_states
# Feed-forward expansion sub-layer: Linear(hidden_size -> intermediate_size) followed by the activation from
# `config.hidden_act`, which is either a key into ACT2FN or an activation callable used as-is.
# forward(hidden_states) returns the activated tensor with last dimension `intermediate_size`.
# NOTE(review): the "Copied from" marker below presumably ties this class to BertIntermediate through the
# repository's copy-consistency check, so the class body is intentionally kept byte-identical.
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->LayoutLMv2
class LayoutLMv2Intermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
# Feed-forward output sub-layer: Linear(intermediate_size -> hidden_size), dropout, then LayerNorm applied to
# the sum with the residual `input_tensor` (the feed-forward block's input).
# NOTE(review): the "Copied from" marker below presumably ties this class to BertOutput through the
# repository's copy-consistency check, so the class body is intentionally kept byte-identical. The
# replacement pattern now names LayoutLMv2, matching the actual class name and LayoutLMv2Intermediate's marker.
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->LayoutLMv2
class LayoutLMv2Output(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
  218. class LayoutLMv2Layer(GradientCheckpointingLayer):
  219. def __init__(self, config):
  220. super().__init__()
  221. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  222. self.seq_len_dim = 1
  223. self.attention = LayoutLMv2Attention(config)
  224. self.intermediate = LayoutLMv2Intermediate(config)
  225. self.output = LayoutLMv2Output(config)
  226. def forward(
  227. self,
  228. hidden_states,
  229. attention_mask=None,
  230. head_mask=None,
  231. output_attentions=False,
  232. rel_pos=None,
  233. rel_2d_pos=None,
  234. ):
  235. self_attention_outputs = self.attention(
  236. hidden_states,
  237. attention_mask,
  238. head_mask,
  239. output_attentions=output_attentions,
  240. rel_pos=rel_pos,
  241. rel_2d_pos=rel_2d_pos,
  242. )
  243. attention_output = self_attention_outputs[0]
  244. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  245. layer_output = apply_chunking_to_forward(
  246. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  247. )
  248. outputs = (layer_output,) + outputs
  249. return outputs
  250. def feed_forward_chunk(self, attention_output):
  251. intermediate_output = self.intermediate(attention_output)
  252. layer_output = self.output(intermediate_output, attention_output)
  253. return layer_output
  254. def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
  255. """
  256. Adapted from Mesh Tensorflow:
  257. https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
  258. Translate relative position to a bucket number for relative attention. The relative position is defined as
  259. memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
  260. position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for small
  261. absolute relative_position and larger buckets for larger absolute relative_positions. All relative positions
  262. >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. This should
  263. allow for more graceful generalization to longer sequences than the model has been trained on.
  264. Args:
  265. relative_position: an int32 Tensor
  266. bidirectional: a boolean - whether the attention is bidirectional
  267. num_buckets: an integer
  268. max_distance: an integer
  269. Returns:
  270. a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
  271. """
  272. ret = 0
  273. if bidirectional:
  274. num_buckets //= 2
  275. ret += (relative_position > 0).long() * num_buckets
  276. n = torch.abs(relative_position)
  277. else:
  278. n = torch.max(-relative_position, torch.zeros_like(relative_position))
  279. # now n is in the range [0, inf)
  280. # half of the buckets are for exact increments in positions
  281. max_exact = num_buckets // 2
  282. is_small = n < max_exact
  283. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  284. val_if_large = max_exact + (
  285. torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
  286. ).to(torch.long)
  287. val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
  288. ret += torch.where(is_small, n, val_if_large)
  289. return ret
  290. class LayoutLMv2Encoder(nn.Module):
  291. def __init__(self, config):
  292. super().__init__()
  293. self.config = config
  294. self.layer = nn.ModuleList([LayoutLMv2Layer(config) for _ in range(config.num_hidden_layers)])
  295. self.has_relative_attention_bias = config.has_relative_attention_bias
  296. self.has_spatial_attention_bias = config.has_spatial_attention_bias
  297. if self.has_relative_attention_bias:
  298. self.rel_pos_bins = config.rel_pos_bins
  299. self.max_rel_pos = config.max_rel_pos
  300. self.rel_pos_bias = nn.Linear(self.rel_pos_bins, config.num_attention_heads, bias=False)
  301. if self.has_spatial_attention_bias:
  302. self.max_rel_2d_pos = config.max_rel_2d_pos
  303. self.rel_2d_pos_bins = config.rel_2d_pos_bins
  304. self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
  305. self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
  306. self.gradient_checkpointing = False
  307. def _calculate_1d_position_embeddings(self, position_ids):
  308. rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
  309. rel_pos = relative_position_bucket(
  310. rel_pos_mat,
  311. num_buckets=self.rel_pos_bins,
  312. max_distance=self.max_rel_pos,
  313. )
  314. # Since this is a simple indexing operation that is independent of the input,
  315. # no need to track gradients for this operation
  316. #
  317. # Without this no_grad context, training speed slows down significantly
  318. with torch.no_grad():
  319. rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
  320. rel_pos = rel_pos.contiguous()
  321. return rel_pos
  322. def _calculate_2d_position_embeddings(self, bbox):
  323. position_coord_x = bbox[:, :, 0]
  324. position_coord_y = bbox[:, :, 3]
  325. rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
  326. rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)
  327. rel_pos_x = relative_position_bucket(
  328. rel_pos_x_2d_mat,
  329. num_buckets=self.rel_2d_pos_bins,
  330. max_distance=self.max_rel_2d_pos,
  331. )
  332. rel_pos_y = relative_position_bucket(
  333. rel_pos_y_2d_mat,
  334. num_buckets=self.rel_2d_pos_bins,
  335. max_distance=self.max_rel_2d_pos,
  336. )
  337. # Since this is a simple indexing operation that is independent of the input,
  338. # no need to track gradients for this operation
  339. #
  340. # Without this no_grad context, training speed slows down significantly
  341. with torch.no_grad():
  342. rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
  343. rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
  344. rel_pos_x = rel_pos_x.contiguous()
  345. rel_pos_y = rel_pos_y.contiguous()
  346. rel_2d_pos = rel_pos_x + rel_pos_y
  347. return rel_2d_pos
  348. def forward(
  349. self,
  350. hidden_states,
  351. attention_mask=None,
  352. head_mask=None,
  353. output_attentions=False,
  354. output_hidden_states=False,
  355. return_dict=True,
  356. bbox=None,
  357. position_ids=None,
  358. ):
  359. all_hidden_states = () if output_hidden_states else None
  360. all_self_attentions = () if output_attentions else None
  361. rel_pos = self._calculate_1d_position_embeddings(position_ids) if self.has_relative_attention_bias else None
  362. rel_2d_pos = self._calculate_2d_position_embeddings(bbox) if self.has_spatial_attention_bias else None
  363. for i, layer_module in enumerate(self.layer):
  364. if output_hidden_states:
  365. all_hidden_states = all_hidden_states + (hidden_states,)
  366. layer_head_mask = head_mask[i] if head_mask is not None else None
  367. layer_outputs = layer_module(
  368. hidden_states,
  369. attention_mask,
  370. layer_head_mask,
  371. output_attentions,
  372. rel_pos=rel_pos,
  373. rel_2d_pos=rel_2d_pos,
  374. )
  375. hidden_states = layer_outputs[0]
  376. if output_attentions:
  377. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  378. if output_hidden_states:
  379. all_hidden_states = all_hidden_states + (hidden_states,)
  380. if not return_dict:
  381. return tuple(
  382. v
  383. for v in [
  384. hidden_states,
  385. all_hidden_states,
  386. all_self_attentions,
  387. ]
  388. if v is not None
  389. )
  390. return BaseModelOutput(
  391. last_hidden_state=hidden_states,
  392. hidden_states=all_hidden_states,
  393. attentions=all_self_attentions,
  394. )
  395. @auto_docstring
  396. class LayoutLMv2PreTrainedModel(PreTrainedModel):
  397. config: LayoutLMv2Config
  398. base_model_prefix = "layoutlmv2"
  399. def _init_weights(self, module):
  400. """Initialize the weights"""
  401. if isinstance(module, nn.Linear):
  402. # Slightly different from the TF version which uses truncated_normal for initialization
  403. # cf https://github.com/pytorch/pytorch/pull/5617
  404. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  405. if module.bias is not None:
  406. module.bias.data.zero_()
  407. elif isinstance(module, nn.Embedding):
  408. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  409. if module.padding_idx is not None:
  410. module.weight.data[module.padding_idx].zero_()
  411. elif isinstance(module, nn.LayerNorm):
  412. module.bias.data.zero_()
  413. module.weight.data.fill_(1.0)
  414. elif isinstance(module, LayoutLMv2SelfAttention):
  415. if self.config.fast_qkv:
  416. module.q_bias.data.zero_()
  417. module.v_bias.data.zero_()
  418. elif isinstance(module, LayoutLMv2Model):
  419. if hasattr(module, "visual_segment_embedding"):
  420. module.visual_segment_embedding.data.normal_(mean=0.0, std=self.config.initializer_range)
  421. def my_convert_sync_batchnorm(module, process_group=None):
  422. # same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d`
  423. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  424. return nn.modules.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
  425. module_output = module
  426. if isinstance(module, detectron2.layers.FrozenBatchNorm2d):
  427. module_output = torch.nn.SyncBatchNorm(
  428. num_features=module.num_features,
  429. eps=module.eps,
  430. affine=True,
  431. track_running_stats=True,
  432. process_group=process_group,
  433. )
  434. module_output.weight = torch.nn.Parameter(module.weight)
  435. module_output.bias = torch.nn.Parameter(module.bias)
  436. module_output.running_mean = module.running_mean
  437. module_output.running_var = module.running_var
  438. module_output.num_batches_tracked = torch.tensor(0, dtype=torch.long, device=module.running_mean.device)
  439. for name, child in module.named_children():
  440. module_output.add_module(name, my_convert_sync_batchnorm(child, process_group))
  441. del module
  442. return module_output
  443. class LayoutLMv2VisualBackbone(nn.Module):
  444. def __init__(self, config):
  445. super().__init__()
  446. self.cfg = config.get_detectron2_config()
  447. meta_arch = self.cfg.MODEL.META_ARCHITECTURE
  448. model = META_ARCH_REGISTRY.get(meta_arch)(self.cfg)
  449. assert isinstance(model.backbone, detectron2.modeling.backbone.FPN)
  450. self.backbone = model.backbone
  451. assert len(self.cfg.MODEL.PIXEL_MEAN) == len(self.cfg.MODEL.PIXEL_STD)
  452. num_channels = len(self.cfg.MODEL.PIXEL_MEAN)
  453. self.register_buffer(
  454. "pixel_mean",
  455. torch.Tensor(self.cfg.MODEL.PIXEL_MEAN).view(num_channels, 1, 1),
  456. persistent=False,
  457. )
  458. self.register_buffer(
  459. "pixel_std", torch.Tensor(self.cfg.MODEL.PIXEL_STD).view(num_channels, 1, 1), persistent=False
  460. )
  461. self.out_feature_key = "p2"
  462. if torch.are_deterministic_algorithms_enabled():
  463. logger.warning("using `AvgPool2d` instead of `AdaptiveAvgPool2d`")
  464. input_shape = (224, 224)
  465. backbone_stride = self.backbone.output_shape()[self.out_feature_key].stride
  466. self.pool = nn.AvgPool2d(
  467. (
  468. math.ceil(math.ceil(input_shape[0] / backbone_stride) / config.image_feature_pool_shape[0]),
  469. math.ceil(math.ceil(input_shape[1] / backbone_stride) / config.image_feature_pool_shape[1]),
  470. )
  471. )
  472. else:
  473. self.pool = nn.AdaptiveAvgPool2d(config.image_feature_pool_shape[:2])
  474. if len(config.image_feature_pool_shape) == 2:
  475. config.image_feature_pool_shape.append(self.backbone.output_shape()[self.out_feature_key].channels)
  476. assert self.backbone.output_shape()[self.out_feature_key].channels == config.image_feature_pool_shape[2]
  477. def forward(self, images):
  478. images_input = ((images if torch.is_tensor(images) else images.tensor) - self.pixel_mean) / self.pixel_std
  479. features = self.backbone(images_input)
  480. features = features[self.out_feature_key]
  481. features = self.pool(features).flatten(start_dim=2).transpose(1, 2).contiguous()
  482. return features
  483. def synchronize_batch_norm(self):
  484. if not (
  485. torch.distributed.is_available()
  486. and torch.distributed.is_initialized()
  487. and torch.distributed.get_rank() > -1
  488. ):
  489. raise RuntimeError("Make sure torch.distributed is set up properly.")
  490. self_rank = torch.distributed.get_rank()
  491. node_size = torch.cuda.device_count()
  492. world_size = torch.distributed.get_world_size()
  493. if not (world_size % node_size == 0):
  494. raise RuntimeError("Make sure the number of processes can be divided by the number of nodes")
  495. node_global_ranks = [list(range(i * node_size, (i + 1) * node_size)) for i in range(world_size // node_size)]
  496. sync_bn_groups = [
  497. torch.distributed.new_group(ranks=node_global_ranks[i]) for i in range(world_size // node_size)
  498. ]
  499. node_rank = self_rank // node_size
  500. self.backbone = my_convert_sync_batchnorm(self.backbone, process_group=sync_bn_groups[node_rank])
  501. class LayoutLMv2Pooler(nn.Module):
  502. def __init__(self, config):
  503. super().__init__()
  504. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  505. self.activation = nn.Tanh()
  506. def forward(self, hidden_states):
  507. # We "pool" the model by simply taking the hidden state corresponding
  508. # to the first token.
  509. first_token_tensor = hidden_states[:, 0]
  510. pooled_output = self.dense(first_token_tensor)
  511. pooled_output = self.activation(pooled_output)
  512. return pooled_output
  513. @auto_docstring
  514. class LayoutLMv2Model(LayoutLMv2PreTrainedModel):
    def __init__(self, config):
        # Fail early when the detectron2 backend (needed by the visual backbone) is unavailable.
        requires_backends(self, "detectron2")
        super().__init__(config)
        self.config = config
        self.has_visual_segment_embedding = config.has_visual_segment_embedding
        # Text-side tables: word, 1-D position, 2-D layout and token-type embeddings.
        self.embeddings = LayoutLMv2Embeddings(config)

        # Visual branch. LayoutLMv2VisualBackbone may append the backbone channel count to
        # config.image_feature_pool_shape, so it must be built before visual_proj reads [-1] below.
        self.visual = LayoutLMv2VisualBackbone(config)
        self.visual_proj = nn.Linear(config.image_feature_pool_shape[-1], config.hidden_size)
        if self.has_visual_segment_embedding:
            # One learned (hidden_size,) vector added to every visual token in _calc_img_embeddings. The Embedding
            # row only provides an initial value; LayoutLMv2PreTrainedModel._init_weights re-draws it
            # (presumably when post_init runs).
            self.visual_segment_embedding = nn.Parameter(nn.Embedding(1, config.hidden_size).weight[0])
        self.visual_LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.visual_dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = LayoutLMv2Encoder(config)
        self.pooler = LayoutLMv2Pooler(config)

        # Initialize weights and apply final processing
        self.post_init()
  531. def get_input_embeddings(self):
  532. return self.embeddings.word_embeddings
  533. def set_input_embeddings(self, value):
  534. self.embeddings.word_embeddings = value
  535. def _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids, inputs_embeds=None):
  536. if input_ids is not None:
  537. input_shape = input_ids.size()
  538. else:
  539. input_shape = inputs_embeds.size()[:-1]
  540. seq_length = input_shape[1]
  541. if position_ids is None:
  542. position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
  543. position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
  544. if token_type_ids is None:
  545. token_type_ids = torch.zeros_like(input_ids)
  546. if inputs_embeds is None:
  547. inputs_embeds = self.embeddings.word_embeddings(input_ids)
  548. position_embeddings = self.embeddings.position_embeddings(position_ids)
  549. spatial_position_embeddings = self.embeddings._calc_spatial_position_embeddings(bbox)
  550. token_type_embeddings = self.embeddings.token_type_embeddings(token_type_ids)
  551. embeddings = inputs_embeds + position_embeddings + spatial_position_embeddings + token_type_embeddings
  552. embeddings = self.embeddings.LayerNorm(embeddings)
  553. embeddings = self.embeddings.dropout(embeddings)
  554. return embeddings
  555. def _calc_img_embeddings(self, image, bbox, position_ids):
  556. visual_embeddings = self.visual_proj(self.visual(image))
  557. position_embeddings = self.embeddings.position_embeddings(position_ids)
  558. spatial_position_embeddings = self.embeddings._calc_spatial_position_embeddings(bbox)
  559. embeddings = visual_embeddings + position_embeddings + spatial_position_embeddings
  560. if self.has_visual_segment_embedding:
  561. embeddings += self.visual_segment_embedding
  562. embeddings = self.visual_LayerNorm(embeddings)
  563. embeddings = self.visual_dropout(embeddings)
  564. return embeddings
  565. def _calc_visual_bbox(self, image_feature_pool_shape, bbox, device, final_shape):
  566. visual_bbox_x = torch.div(
  567. torch.arange(
  568. 0,
  569. 1000 * (image_feature_pool_shape[1] + 1),
  570. 1000,
  571. device=device,
  572. dtype=bbox.dtype,
  573. ),
  574. self.config.image_feature_pool_shape[1],
  575. rounding_mode="floor",
  576. )
  577. visual_bbox_y = torch.div(
  578. torch.arange(
  579. 0,
  580. 1000 * (self.config.image_feature_pool_shape[0] + 1),
  581. 1000,
  582. device=device,
  583. dtype=bbox.dtype,
  584. ),
  585. self.config.image_feature_pool_shape[0],
  586. rounding_mode="floor",
  587. )
  588. visual_bbox = torch.stack(
  589. [
  590. visual_bbox_x[:-1].repeat(image_feature_pool_shape[0], 1),
  591. visual_bbox_y[:-1].repeat(image_feature_pool_shape[1], 1).transpose(0, 1),
  592. visual_bbox_x[1:].repeat(image_feature_pool_shape[0], 1),
  593. visual_bbox_y[1:].repeat(image_feature_pool_shape[1], 1).transpose(0, 1),
  594. ],
  595. dim=-1,
  596. ).view(-1, bbox.size(-1))
  597. visual_bbox = visual_bbox.repeat(final_shape[0], 1, 1)
  598. return visual_bbox
  599. def _get_input_shape(self, input_ids=None, inputs_embeds=None):
  600. if input_ids is not None and inputs_embeds is not None:
  601. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  602. elif input_ids is not None:
  603. return input_ids.size()
  604. elif inputs_embeds is not None:
  605. return inputs_embeds.size()[:-1]
  606. else:
  607. raise ValueError("You have to specify either input_ids or inputs_embeds")
  608. @auto_docstring
  609. def forward(
  610. self,
  611. input_ids: Optional[torch.LongTensor] = None,
  612. bbox: Optional[torch.LongTensor] = None,
  613. image: Optional[torch.FloatTensor] = None,
  614. attention_mask: Optional[torch.FloatTensor] = None,
  615. token_type_ids: Optional[torch.LongTensor] = None,
  616. position_ids: Optional[torch.LongTensor] = None,
  617. head_mask: Optional[torch.FloatTensor] = None,
  618. inputs_embeds: Optional[torch.FloatTensor] = None,
  619. output_attentions: Optional[bool] = None,
  620. output_hidden_states: Optional[bool] = None,
  621. return_dict: Optional[bool] = None,
  622. ) -> Union[tuple, BaseModelOutputWithPooling]:
  623. r"""
  624. bbox (`torch.LongTensor` of shape `((batch_size, sequence_length), 4)`, *optional*):
  625. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  626. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  627. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  628. y1) represents the position of the lower right corner.
  629. image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`):
  630. Batch of document images.
  631. Examples:
  632. ```python
  633. >>> from transformers import AutoProcessor, LayoutLMv2Model, set_seed
  634. >>> from PIL import Image
  635. >>> import torch
  636. >>> from datasets import load_dataset
  637. >>> set_seed(0)
  638. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
  639. >>> model = LayoutLMv2Model.from_pretrained("microsoft/layoutlmv2-base-uncased")
  640. >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa")
  641. >>> image = dataset["test"][0]["image"]
  642. >>> encoding = processor(image, return_tensors="pt")
  643. >>> outputs = model(**encoding)
  644. >>> last_hidden_states = outputs.last_hidden_state
  645. >>> last_hidden_states.shape
  646. torch.Size([1, 342, 768])
  647. ```
  648. """
  649. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  650. output_hidden_states = (
  651. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  652. )
  653. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  654. input_shape = self._get_input_shape(input_ids, inputs_embeds)
  655. device = input_ids.device if input_ids is not None else inputs_embeds.device
  656. visual_shape = list(input_shape)
  657. visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
  658. visual_shape = torch.Size(visual_shape)
  659. # needs a new copy of input_shape for tracing. Otherwise wrong dimensions will occur
  660. final_shape = list(self._get_input_shape(input_ids, inputs_embeds))
  661. final_shape[1] += visual_shape[1]
  662. final_shape = torch.Size(final_shape)
  663. visual_bbox = self._calc_visual_bbox(self.config.image_feature_pool_shape, bbox, device, final_shape)
  664. final_bbox = torch.cat([bbox, visual_bbox], dim=1)
  665. if attention_mask is None:
  666. attention_mask = torch.ones(input_shape, device=device)
  667. visual_attention_mask = torch.ones(visual_shape, device=device)
  668. final_attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1)
  669. if token_type_ids is None:
  670. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  671. if position_ids is None:
  672. seq_length = input_shape[1]
  673. position_ids = self.embeddings.position_ids[:, :seq_length]
  674. position_ids = position_ids.expand(input_shape)
  675. visual_position_ids = torch.arange(0, visual_shape[1], dtype=torch.long, device=device).repeat(
  676. input_shape[0], 1
  677. )
  678. final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1)
  679. if bbox is None:
  680. bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
  681. text_layout_emb = self._calc_text_embeddings(
  682. input_ids=input_ids,
  683. bbox=bbox,
  684. token_type_ids=token_type_ids,
  685. position_ids=position_ids,
  686. inputs_embeds=inputs_embeds,
  687. )
  688. visual_emb = self._calc_img_embeddings(
  689. image=image,
  690. bbox=visual_bbox,
  691. position_ids=visual_position_ids,
  692. )
  693. final_emb = torch.cat([text_layout_emb, visual_emb], dim=1)
  694. extended_attention_mask = final_attention_mask.unsqueeze(1).unsqueeze(2)
  695. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
  696. extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
  697. if head_mask is not None:
  698. if head_mask.dim() == 1:
  699. head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
  700. head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
  701. elif head_mask.dim() == 2:
  702. head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
  703. head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
  704. else:
  705. head_mask = [None] * self.config.num_hidden_layers
  706. encoder_outputs = self.encoder(
  707. final_emb,
  708. extended_attention_mask,
  709. bbox=final_bbox,
  710. position_ids=final_position_ids,
  711. head_mask=head_mask,
  712. output_attentions=output_attentions,
  713. output_hidden_states=output_hidden_states,
  714. return_dict=return_dict,
  715. )
  716. sequence_output = encoder_outputs[0]
  717. pooled_output = self.pooler(sequence_output)
  718. if not return_dict:
  719. return (sequence_output, pooled_output) + encoder_outputs[1:]
  720. return BaseModelOutputWithPooling(
  721. last_hidden_state=sequence_output,
  722. pooler_output=pooled_output,
  723. hidden_states=encoder_outputs.hidden_states,
  724. attentions=encoder_outputs.attentions,
  725. )
  726. @auto_docstring(
  727. custom_intro="""
  728. LayoutLMv2 Model with a sequence classification head on top (a linear layer on top of the concatenation of the
  729. final hidden state of the [CLS] token, average-pooled initial visual embeddings and average-pooled final visual
  730. embeddings, e.g. for document image classification tasks such as the
  731. [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset.
  732. """
  733. )
  734. class LayoutLMv2ForSequenceClassification(LayoutLMv2PreTrainedModel):
  735. def __init__(self, config):
  736. super().__init__(config)
  737. self.num_labels = config.num_labels
  738. self.layoutlmv2 = LayoutLMv2Model(config)
  739. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  740. self.classifier = nn.Linear(config.hidden_size * 3, config.num_labels)
  741. # Initialize weights and apply final processing
  742. self.post_init()
  743. def get_input_embeddings(self):
  744. return self.layoutlmv2.embeddings.word_embeddings
  745. @auto_docstring
  746. def forward(
  747. self,
  748. input_ids: Optional[torch.LongTensor] = None,
  749. bbox: Optional[torch.LongTensor] = None,
  750. image: Optional[torch.FloatTensor] = None,
  751. attention_mask: Optional[torch.FloatTensor] = None,
  752. token_type_ids: Optional[torch.LongTensor] = None,
  753. position_ids: Optional[torch.LongTensor] = None,
  754. head_mask: Optional[torch.FloatTensor] = None,
  755. inputs_embeds: Optional[torch.FloatTensor] = None,
  756. labels: Optional[torch.LongTensor] = None,
  757. output_attentions: Optional[bool] = None,
  758. output_hidden_states: Optional[bool] = None,
  759. return_dict: Optional[bool] = None,
  760. ) -> Union[tuple, SequenceClassifierOutput]:
  761. r"""
  762. input_ids (`torch.LongTensor` of shape `batch_size, sequence_length`):
  763. Indices of input sequence tokens in the vocabulary.
  764. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  765. [`PreTrainedTokenizer.__call__`] for details.
  766. [What are input IDs?](../glossary#input-ids)
  767. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  768. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  769. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  770. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  771. y1) represents the position of the lower right corner.
  772. image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`):
  773. Batch of document images.
  774. token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  775. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  776. 1]`:
  777. - 0 corresponds to a *sentence A* token,
  778. - 1 corresponds to a *sentence B* token.
  779. [What are token type IDs?](../glossary#token-type-ids)
  780. position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  781. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  782. config.max_position_embeddings - 1]`.
  783. [What are position IDs?](../glossary#position-ids)
  784. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  785. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  786. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  787. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  788. Example:
  789. ```python
  790. >>> from transformers import AutoProcessor, LayoutLMv2ForSequenceClassification, set_seed
  791. >>> from PIL import Image
  792. >>> import torch
  793. >>> from datasets import load_dataset
  794. >>> set_seed(0)
  795. >>> dataset = load_dataset("aharley/rvl_cdip", split="train", streaming=True)
  796. >>> data = next(iter(dataset))
  797. >>> image = data["image"].convert("RGB")
  798. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
  799. >>> model = LayoutLMv2ForSequenceClassification.from_pretrained(
  800. ... "microsoft/layoutlmv2-base-uncased", num_labels=dataset.info.features["label"].num_classes
  801. ... )
  802. >>> encoding = processor(image, return_tensors="pt")
  803. >>> sequence_label = torch.tensor([data["label"]])
  804. >>> outputs = model(**encoding, labels=sequence_label)
  805. >>> loss, logits = outputs.loss, outputs.logits
  806. >>> predicted_idx = logits.argmax(dim=-1).item()
  807. >>> predicted_answer = dataset.info.features["label"].names[4]
  808. >>> predicted_idx, predicted_answer # results are not good without further fine-tuning
  809. (7, 'advertisement')
  810. ```
  811. """
  812. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  813. if input_ids is not None and inputs_embeds is not None:
  814. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  815. elif input_ids is not None:
  816. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  817. input_shape = input_ids.size()
  818. elif inputs_embeds is not None:
  819. input_shape = inputs_embeds.size()[:-1]
  820. else:
  821. raise ValueError("You have to specify either input_ids or inputs_embeds")
  822. device = input_ids.device if input_ids is not None else inputs_embeds.device
  823. visual_shape = list(input_shape)
  824. visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
  825. visual_shape = torch.Size(visual_shape)
  826. final_shape = list(input_shape)
  827. final_shape[1] += visual_shape[1]
  828. final_shape = torch.Size(final_shape)
  829. visual_bbox = self.layoutlmv2._calc_visual_bbox(
  830. self.config.image_feature_pool_shape, bbox, device, final_shape
  831. )
  832. visual_position_ids = torch.arange(0, visual_shape[1], dtype=torch.long, device=device).repeat(
  833. input_shape[0], 1
  834. )
  835. initial_image_embeddings = self.layoutlmv2._calc_img_embeddings(
  836. image=image,
  837. bbox=visual_bbox,
  838. position_ids=visual_position_ids,
  839. )
  840. outputs = self.layoutlmv2(
  841. input_ids=input_ids,
  842. bbox=bbox,
  843. image=image,
  844. attention_mask=attention_mask,
  845. token_type_ids=token_type_ids,
  846. position_ids=position_ids,
  847. head_mask=head_mask,
  848. inputs_embeds=inputs_embeds,
  849. output_attentions=output_attentions,
  850. output_hidden_states=output_hidden_states,
  851. return_dict=return_dict,
  852. )
  853. if input_ids is not None:
  854. input_shape = input_ids.size()
  855. else:
  856. input_shape = inputs_embeds.size()[:-1]
  857. seq_length = input_shape[1]
  858. sequence_output, final_image_embeddings = outputs[0][:, :seq_length], outputs[0][:, seq_length:]
  859. cls_final_output = sequence_output[:, 0, :]
  860. # average-pool the visual embeddings
  861. pooled_initial_image_embeddings = initial_image_embeddings.mean(dim=1)
  862. pooled_final_image_embeddings = final_image_embeddings.mean(dim=1)
  863. # concatenate with cls_final_output
  864. sequence_output = torch.cat(
  865. [cls_final_output, pooled_initial_image_embeddings, pooled_final_image_embeddings], dim=1
  866. )
  867. sequence_output = self.dropout(sequence_output)
  868. logits = self.classifier(sequence_output)
  869. loss = None
  870. if labels is not None:
  871. if self.config.problem_type is None:
  872. if self.num_labels == 1:
  873. self.config.problem_type = "regression"
  874. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  875. self.config.problem_type = "single_label_classification"
  876. else:
  877. self.config.problem_type = "multi_label_classification"
  878. if self.config.problem_type == "regression":
  879. loss_fct = MSELoss()
  880. if self.num_labels == 1:
  881. loss = loss_fct(logits.squeeze(), labels.squeeze())
  882. else:
  883. loss = loss_fct(logits, labels)
  884. elif self.config.problem_type == "single_label_classification":
  885. loss_fct = CrossEntropyLoss()
  886. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  887. elif self.config.problem_type == "multi_label_classification":
  888. loss_fct = BCEWithLogitsLoss()
  889. loss = loss_fct(logits, labels)
  890. if not return_dict:
  891. output = (logits,) + outputs[2:]
  892. return ((loss,) + output) if loss is not None else output
  893. return SequenceClassifierOutput(
  894. loss=loss,
  895. logits=logits,
  896. hidden_states=outputs.hidden_states,
  897. attentions=outputs.attentions,
  898. )
  899. @auto_docstring(
  900. custom_intro="""
  901. LayoutLMv2 Model with a token classification head on top (a linear layer on top of the text part of the hidden
  902. states) e.g. for sequence labeling (information extraction) tasks such as
  903. [FUNSD](https://guillaumejaume.github.io/FUNSD/), [SROIE](https://rrc.cvc.uab.es/?ch=13),
  904. [CORD](https://github.com/clovaai/cord) and [Kleister-NDA](https://github.com/applicaai/kleister-nda).
  905. """
  906. )
  907. class LayoutLMv2ForTokenClassification(LayoutLMv2PreTrainedModel):
  908. def __init__(self, config):
  909. super().__init__(config)
  910. self.num_labels = config.num_labels
  911. self.layoutlmv2 = LayoutLMv2Model(config)
  912. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  913. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  914. # Initialize weights and apply final processing
  915. self.post_init()
  916. def get_input_embeddings(self):
  917. return self.layoutlmv2.embeddings.word_embeddings
  918. @auto_docstring
  919. def forward(
  920. self,
  921. input_ids: Optional[torch.LongTensor] = None,
  922. bbox: Optional[torch.LongTensor] = None,
  923. image: Optional[torch.FloatTensor] = None,
  924. attention_mask: Optional[torch.FloatTensor] = None,
  925. token_type_ids: Optional[torch.LongTensor] = None,
  926. position_ids: Optional[torch.LongTensor] = None,
  927. head_mask: Optional[torch.FloatTensor] = None,
  928. inputs_embeds: Optional[torch.FloatTensor] = None,
  929. labels: Optional[torch.LongTensor] = None,
  930. output_attentions: Optional[bool] = None,
  931. output_hidden_states: Optional[bool] = None,
  932. return_dict: Optional[bool] = None,
  933. ) -> Union[tuple, TokenClassifierOutput]:
  934. r"""
  935. input_ids (`torch.LongTensor` of shape `batch_size, sequence_length`):
  936. Indices of input sequence tokens in the vocabulary.
  937. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  938. [`PreTrainedTokenizer.__call__`] for details.
  939. [What are input IDs?](../glossary#input-ids)
  940. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  941. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  942. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  943. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  944. y1) represents the position of the lower right corner.
  945. image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`):
  946. Batch of document images.
  947. token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  948. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  949. 1]`:
  950. - 0 corresponds to a *sentence A* token,
  951. - 1 corresponds to a *sentence B* token.
  952. [What are token type IDs?](../glossary#token-type-ids)
  953. position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  954. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  955. config.max_position_embeddings - 1]`.
  956. [What are position IDs?](../glossary#position-ids)
  957. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  958. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  959. Example:
  960. ```python
  961. >>> from transformers import AutoProcessor, LayoutLMv2ForTokenClassification, set_seed
  962. >>> from PIL import Image
  963. >>> from datasets import load_dataset
  964. >>> set_seed(0)
  965. >>> datasets = load_dataset("nielsr/funsd", split="test")
  966. >>> labels = datasets.features["ner_tags"].feature.names
  967. >>> id2label = {v: k for v, k in enumerate(labels)}
  968. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
  969. >>> model = LayoutLMv2ForTokenClassification.from_pretrained(
  970. ... "microsoft/layoutlmv2-base-uncased", num_labels=len(labels)
  971. ... )
  972. >>> data = datasets[0]
  973. >>> image = Image.open(data["image_path"]).convert("RGB")
  974. >>> words = data["words"]
  975. >>> boxes = data["bboxes"] # make sure to normalize your bounding boxes
  976. >>> word_labels = data["ner_tags"]
  977. >>> encoding = processor(
  978. ... image,
  979. ... words,
  980. ... boxes=boxes,
  981. ... word_labels=word_labels,
  982. ... padding="max_length",
  983. ... truncation=True,
  984. ... return_tensors="pt",
  985. ... )
  986. >>> outputs = model(**encoding)
  987. >>> logits, loss = outputs.logits, outputs.loss
  988. >>> predicted_token_class_ids = logits.argmax(-1)
  989. >>> predicted_tokens_classes = [id2label[t.item()] for t in predicted_token_class_ids[0]]
  990. >>> predicted_tokens_classes[:5] # results are not good without further fine-tuning
  991. ['I-HEADER', 'I-HEADER', 'I-QUESTION', 'I-HEADER', 'I-QUESTION']
  992. ```
  993. """
  994. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  995. outputs = self.layoutlmv2(
  996. input_ids=input_ids,
  997. bbox=bbox,
  998. image=image,
  999. attention_mask=attention_mask,
  1000. token_type_ids=token_type_ids,
  1001. position_ids=position_ids,
  1002. head_mask=head_mask,
  1003. inputs_embeds=inputs_embeds,
  1004. output_attentions=output_attentions,
  1005. output_hidden_states=output_hidden_states,
  1006. return_dict=return_dict,
  1007. )
  1008. if input_ids is not None:
  1009. input_shape = input_ids.size()
  1010. else:
  1011. input_shape = inputs_embeds.size()[:-1]
  1012. seq_length = input_shape[1]
  1013. # only take the text part of the output representations
  1014. sequence_output = outputs[0][:, :seq_length]
  1015. sequence_output = self.dropout(sequence_output)
  1016. logits = self.classifier(sequence_output)
  1017. loss = None
  1018. if labels is not None:
  1019. loss_fct = CrossEntropyLoss()
  1020. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1021. if not return_dict:
  1022. output = (logits,) + outputs[2:]
  1023. return ((loss,) + output) if loss is not None else output
  1024. return TokenClassifierOutput(
  1025. loss=loss,
  1026. logits=logits,
  1027. hidden_states=outputs.hidden_states,
  1028. attentions=outputs.attentions,
  1029. )
  1030. @auto_docstring
  1031. class LayoutLMv2ForQuestionAnswering(LayoutLMv2PreTrainedModel):
  1032. def __init__(self, config, has_visual_segment_embedding=True):
  1033. r"""
  1034. has_visual_segment_embedding (`bool`, *optional*, defaults to `True`):
  1035. Whether or not to add visual segment embeddings.
  1036. """
  1037. super().__init__(config)
  1038. self.num_labels = config.num_labels
  1039. config.has_visual_segment_embedding = has_visual_segment_embedding
  1040. self.layoutlmv2 = LayoutLMv2Model(config)
  1041. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1042. # Initialize weights and apply final processing
  1043. self.post_init()
  1044. def get_input_embeddings(self):
  1045. return self.layoutlmv2.embeddings.word_embeddings
  1046. @auto_docstring
  1047. def forward(
  1048. self,
  1049. input_ids: Optional[torch.LongTensor] = None,
  1050. bbox: Optional[torch.LongTensor] = None,
  1051. image: Optional[torch.FloatTensor] = None,
  1052. attention_mask: Optional[torch.FloatTensor] = None,
  1053. token_type_ids: Optional[torch.LongTensor] = None,
  1054. position_ids: Optional[torch.LongTensor] = None,
  1055. head_mask: Optional[torch.FloatTensor] = None,
  1056. inputs_embeds: Optional[torch.FloatTensor] = None,
  1057. start_positions: Optional[torch.LongTensor] = None,
  1058. end_positions: Optional[torch.LongTensor] = None,
  1059. output_attentions: Optional[bool] = None,
  1060. output_hidden_states: Optional[bool] = None,
  1061. return_dict: Optional[bool] = None,
  1062. ) -> Union[tuple, QuestionAnsweringModelOutput]:
  1063. r"""
  1064. input_ids (`torch.LongTensor` of shape `batch_size, sequence_length`):
  1065. Indices of input sequence tokens in the vocabulary.
  1066. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1067. [`PreTrainedTokenizer.__call__`] for details.
  1068. [What are input IDs?](../glossary#input-ids)
  1069. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  1070. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  1071. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  1072. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  1073. y1) represents the position of the lower right corner.
  1074. image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`):
  1075. Batch of document images.
  1076. token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  1077. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1078. 1]`:
  1079. - 0 corresponds to a *sentence A* token,
  1080. - 1 corresponds to a *sentence B* token.
  1081. [What are token type IDs?](../glossary#token-type-ids)
  1082. position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  1083. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1084. config.max_position_embeddings - 1]`.
  1085. [What are position IDs?](../glossary#position-ids)
  1086. Example:
  1087. In this example below, we give the LayoutLMv2 model an image (of texts) and ask it a question. It will give us
  1088. a prediction of what it thinks the answer is (the span of the answer within the texts parsed from the image).
  1089. ```python
  1090. >>> from transformers import AutoProcessor, LayoutLMv2ForQuestionAnswering, set_seed
  1091. >>> import torch
  1092. >>> from PIL import Image
  1093. >>> from datasets import load_dataset
  1094. >>> set_seed(0)
  1095. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
  1096. >>> model = LayoutLMv2ForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-base-uncased")
  1097. >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa")
  1098. >>> image = dataset["test"][0]["image"]
  1099. >>> question = "When is coffee break?"
  1100. >>> encoding = processor(image, question, return_tensors="pt")
  1101. >>> outputs = model(**encoding)
  1102. >>> predicted_start_idx = outputs.start_logits.argmax(-1).item()
  1103. >>> predicted_end_idx = outputs.end_logits.argmax(-1).item()
  1104. >>> predicted_start_idx, predicted_end_idx
  1105. (30, 191)
  1106. >>> predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
  1107. >>> predicted_answer = processor.tokenizer.decode(predicted_answer_tokens)
  1108. >>> predicted_answer # results are not good without further fine-tuning
  1109. '44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president “ introductory remarks ” lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from'
  1110. ```
  1111. ```python
  1112. >>> target_start_index = torch.tensor([7])
  1113. >>> target_end_index = torch.tensor([14])
  1114. >>> outputs = model(**encoding, start_positions=target_start_index, end_positions=target_end_index)
  1115. >>> predicted_answer_span_start = outputs.start_logits.argmax(-1).item()
  1116. >>> predicted_answer_span_end = outputs.end_logits.argmax(-1).item()
  1117. >>> predicted_answer_span_start, predicted_answer_span_end
  1118. (30, 191)
  1119. ```
  1120. """
  1121. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1122. outputs = self.layoutlmv2(
  1123. input_ids=input_ids,
  1124. bbox=bbox,
  1125. image=image,
  1126. attention_mask=attention_mask,
  1127. token_type_ids=token_type_ids,
  1128. position_ids=position_ids,
  1129. head_mask=head_mask,
  1130. inputs_embeds=inputs_embeds,
  1131. output_attentions=output_attentions,
  1132. output_hidden_states=output_hidden_states,
  1133. return_dict=return_dict,
  1134. )
  1135. if input_ids is not None:
  1136. input_shape = input_ids.size()
  1137. else:
  1138. input_shape = inputs_embeds.size()[:-1]
  1139. seq_length = input_shape[1]
  1140. # only take the text part of the output representations
  1141. sequence_output = outputs[0][:, :seq_length]
  1142. logits = self.qa_outputs(sequence_output)
  1143. start_logits, end_logits = logits.split(1, dim=-1)
  1144. start_logits = start_logits.squeeze(-1).contiguous()
  1145. end_logits = end_logits.squeeze(-1).contiguous()
  1146. total_loss = None
  1147. if start_positions is not None and end_positions is not None:
  1148. # If we are on multi-GPU, split add a dimension
  1149. if len(start_positions.size()) > 1:
  1150. start_positions = start_positions.squeeze(-1)
  1151. if len(end_positions.size()) > 1:
  1152. end_positions = end_positions.squeeze(-1)
  1153. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1154. ignored_index = start_logits.size(1)
  1155. start_positions = start_positions.clamp(0, ignored_index)
  1156. end_positions = end_positions.clamp(0, ignored_index)
  1157. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1158. start_loss = loss_fct(start_logits, start_positions)
  1159. end_loss = loss_fct(end_logits, end_positions)
  1160. total_loss = (start_loss + end_loss) / 2
  1161. if not return_dict:
  1162. output = (start_logits, end_logits) + outputs[2:]
  1163. return ((total_loss,) + output) if total_loss is not None else output
  1164. return QuestionAnsweringModelOutput(
  1165. loss=total_loss,
  1166. start_logits=start_logits,
  1167. end_logits=end_logits,
  1168. hidden_states=outputs.hidden_states,
  1169. attentions=outputs.attentions,
  1170. )
  1171. __all__ = [
  1172. "LayoutLMv2ForQuestionAnswering",
  1173. "LayoutLMv2ForSequenceClassification",
  1174. "LayoutLMv2ForTokenClassification",
  1175. "LayoutLMv2Layer",
  1176. "LayoutLMv2Model",
  1177. "LayoutLMv2PreTrainedModel",
  1178. ]