modeling_markuplm.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051
  1. # coding=utf-8
  2. # Copyright 2022 Microsoft Research Asia and the HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch MarkupLM model."""
  16. import os
  17. from typing import Callable, Optional, Union
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. BaseModelOutputWithPooling,
  26. MaskedLMOutput,
  27. QuestionAnsweringModelOutput,
  28. SequenceClassifierOutput,
  29. TokenClassifierOutput,
  30. )
  31. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  32. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  33. from ...utils import auto_docstring, can_return_tuple, logging
  34. from .configuration_markuplm import MarkupLMConfig
# Module-level logger obtained from the library's logging utilities.
logger = logging.get_logger(__name__)
  36. class XPathEmbeddings(nn.Module):
  37. """Construct the embeddings from xpath tags and subscripts.
  38. We drop tree-id in this version, as its info can be covered by xpath.
  39. """
  40. def __init__(self, config):
  41. super().__init__()
  42. self.max_depth = config.max_depth
  43. self.xpath_unitseq2_embeddings = nn.Linear(config.xpath_unit_hidden_size * self.max_depth, config.hidden_size)
  44. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  45. self.activation = nn.ReLU()
  46. self.xpath_unitseq2_inner = nn.Linear(config.xpath_unit_hidden_size * self.max_depth, 4 * config.hidden_size)
  47. self.inner2emb = nn.Linear(4 * config.hidden_size, config.hidden_size)
  48. self.xpath_tag_sub_embeddings = nn.ModuleList(
  49. [
  50. nn.Embedding(config.max_xpath_tag_unit_embeddings, config.xpath_unit_hidden_size)
  51. for _ in range(self.max_depth)
  52. ]
  53. )
  54. self.xpath_subs_sub_embeddings = nn.ModuleList(
  55. [
  56. nn.Embedding(config.max_xpath_subs_unit_embeddings, config.xpath_unit_hidden_size)
  57. for _ in range(self.max_depth)
  58. ]
  59. )
  60. def forward(self, xpath_tags_seq=None, xpath_subs_seq=None):
  61. xpath_tags_embeddings = []
  62. xpath_subs_embeddings = []
  63. for i in range(self.max_depth):
  64. xpath_tags_embeddings.append(self.xpath_tag_sub_embeddings[i](xpath_tags_seq[:, :, i]))
  65. xpath_subs_embeddings.append(self.xpath_subs_sub_embeddings[i](xpath_subs_seq[:, :, i]))
  66. xpath_tags_embeddings = torch.cat(xpath_tags_embeddings, dim=-1)
  67. xpath_subs_embeddings = torch.cat(xpath_subs_embeddings, dim=-1)
  68. xpath_embeddings = xpath_tags_embeddings + xpath_subs_embeddings
  69. xpath_embeddings = self.inner2emb(self.dropout(self.activation(self.xpath_unitseq2_inner(xpath_embeddings))))
  70. return xpath_embeddings
  71. # Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
  72. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  73. """
  74. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  75. are ignored. This is modified from fairseq's `utils.make_positions`.
  76. Args:
  77. x: torch.Tensor x:
  78. Returns: torch.Tensor
  79. """
  80. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  81. mask = input_ids.ne(padding_idx).int()
  82. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  83. return incremental_indices.long() + padding_idx
  84. class MarkupLMEmbeddings(nn.Module):
  85. """Construct the embeddings from word, position and token_type embeddings."""
  86. def __init__(self, config):
  87. super().__init__()
  88. self.config = config
  89. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  90. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  91. self.max_depth = config.max_depth
  92. self.xpath_embeddings = XPathEmbeddings(config)
  93. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  94. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  95. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  96. self.register_buffer(
  97. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  98. )
  99. self.padding_idx = config.pad_token_id
  100. self.position_embeddings = nn.Embedding(
  101. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  102. )
  103. # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings.create_position_ids_from_inputs_embeds
  104. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  105. """
  106. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  107. Args:
  108. inputs_embeds: torch.Tensor
  109. Returns: torch.Tensor
  110. """
  111. input_shape = inputs_embeds.size()[:-1]
  112. sequence_length = input_shape[1]
  113. position_ids = torch.arange(
  114. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  115. )
  116. return position_ids.unsqueeze(0).expand(input_shape)
  117. def forward(
  118. self,
  119. input_ids=None,
  120. xpath_tags_seq=None,
  121. xpath_subs_seq=None,
  122. token_type_ids=None,
  123. position_ids=None,
  124. inputs_embeds=None,
  125. past_key_values_length=0,
  126. ):
  127. if input_ids is not None:
  128. input_shape = input_ids.size()
  129. else:
  130. input_shape = inputs_embeds.size()[:-1]
  131. device = input_ids.device if input_ids is not None else inputs_embeds.device
  132. if position_ids is None:
  133. if input_ids is not None:
  134. # Create the position ids from the input token ids. Any padded tokens remain padded.
  135. position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
  136. else:
  137. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  138. if token_type_ids is None:
  139. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  140. if inputs_embeds is None:
  141. inputs_embeds = self.word_embeddings(input_ids)
  142. # prepare xpath seq
  143. if xpath_tags_seq is None:
  144. xpath_tags_seq = self.config.tag_pad_id * torch.ones(
  145. tuple(list(input_shape) + [self.max_depth]), dtype=torch.long, device=device
  146. )
  147. if xpath_subs_seq is None:
  148. xpath_subs_seq = self.config.subs_pad_id * torch.ones(
  149. tuple(list(input_shape) + [self.max_depth]), dtype=torch.long, device=device
  150. )
  151. words_embeddings = inputs_embeds
  152. position_embeddings = self.position_embeddings(position_ids)
  153. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  154. xpath_embeddings = self.xpath_embeddings(xpath_tags_seq, xpath_subs_seq)
  155. embeddings = words_embeddings + position_embeddings + token_type_embeddings + xpath_embeddings
  156. embeddings = self.LayerNorm(embeddings)
  157. embeddings = self.dropout(embeddings)
  158. return embeddings
  159. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->MarkupLM
  160. class MarkupLMSelfOutput(nn.Module):
  161. def __init__(self, config):
  162. super().__init__()
  163. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  164. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  165. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  166. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  167. hidden_states = self.dense(hidden_states)
  168. hidden_states = self.dropout(hidden_states)
  169. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  170. return hidden_states
  171. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  172. class MarkupLMIntermediate(nn.Module):
  173. def __init__(self, config):
  174. super().__init__()
  175. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  176. if isinstance(config.hidden_act, str):
  177. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  178. else:
  179. self.intermediate_act_fn = config.hidden_act
  180. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  181. hidden_states = self.dense(hidden_states)
  182. hidden_states = self.intermediate_act_fn(hidden_states)
  183. return hidden_states
  184. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->MarkupLM
  185. class MarkupLMOutput(nn.Module):
  186. def __init__(self, config):
  187. super().__init__()
  188. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  189. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  190. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  191. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  192. hidden_states = self.dense(hidden_states)
  193. hidden_states = self.dropout(hidden_states)
  194. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  195. return hidden_states
  196. # Copied from transformers.models.bert.modeling_bert.BertPooler
  197. class MarkupLMPooler(nn.Module):
  198. def __init__(self, config):
  199. super().__init__()
  200. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  201. self.activation = nn.Tanh()
  202. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  203. # We "pool" the model by simply taking the hidden state corresponding
  204. # to the first token.
  205. first_token_tensor = hidden_states[:, 0]
  206. pooled_output = self.dense(first_token_tensor)
  207. pooled_output = self.activation(pooled_output)
  208. return pooled_output
  209. # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->MarkupLM
  210. class MarkupLMPredictionHeadTransform(nn.Module):
  211. def __init__(self, config):
  212. super().__init__()
  213. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  214. if isinstance(config.hidden_act, str):
  215. self.transform_act_fn = ACT2FN[config.hidden_act]
  216. else:
  217. self.transform_act_fn = config.hidden_act
  218. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  219. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  220. hidden_states = self.dense(hidden_states)
  221. hidden_states = self.transform_act_fn(hidden_states)
  222. hidden_states = self.LayerNorm(hidden_states)
  223. return hidden_states
  224. # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->MarkupLM
  225. class MarkupLMLMPredictionHead(nn.Module):
  226. def __init__(self, config):
  227. super().__init__()
  228. self.transform = MarkupLMPredictionHeadTransform(config)
  229. # The output weights are the same as the input embeddings, but there is
  230. # an output-only bias for each token.
  231. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  232. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  233. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  234. self.decoder.bias = self.bias
  235. def _tie_weights(self):
  236. self.decoder.bias = self.bias
  237. def forward(self, hidden_states):
  238. hidden_states = self.transform(hidden_states)
  239. hidden_states = self.decoder(hidden_states)
  240. return hidden_states
  241. # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->MarkupLM
  242. class MarkupLMOnlyMLMHead(nn.Module):
  243. def __init__(self, config):
  244. super().__init__()
  245. self.predictions = MarkupLMLMPredictionHead(config)
  246. def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
  247. prediction_scores = self.predictions(sequence_output)
  248. return prediction_scores
  249. # Copied from transformers.models.align.modeling_align.eager_attention_forward
  250. def eager_attention_forward(
  251. module: nn.Module,
  252. query: torch.Tensor,
  253. key: torch.Tensor,
  254. value: torch.Tensor,
  255. attention_mask: Optional[torch.Tensor],
  256. scaling: float,
  257. dropout: float = 0.0,
  258. head_mask: Optional[torch.Tensor] = None,
  259. **kwargs,
  260. ):
  261. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  262. if attention_mask is not None:
  263. causal_mask = attention_mask[:, :, :, : key.shape[-2]]
  264. attn_weights = attn_weights + causal_mask
  265. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  266. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  267. if head_mask is not None:
  268. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  269. attn_output = torch.matmul(attn_weights, value)
  270. attn_output = attn_output.transpose(1, 2).contiguous()
  271. return attn_output, attn_weights
  272. # Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->MarkupLM
  273. class MarkupLMSelfAttention(nn.Module):
  274. def __init__(self, config):
  275. super().__init__()
  276. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  277. raise ValueError(
  278. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  279. f"heads ({config.num_attention_heads})"
  280. )
  281. self.config = config
  282. self.num_attention_heads = config.num_attention_heads
  283. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  284. self.all_head_size = self.num_attention_heads * self.attention_head_size
  285. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  286. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  287. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  288. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  289. self.attention_dropout = config.attention_probs_dropout_prob
  290. self.scaling = self.attention_head_size**-0.5
  291. def forward(
  292. self,
  293. hidden_states: torch.Tensor,
  294. attention_mask: Optional[torch.FloatTensor] = None,
  295. head_mask: Optional[torch.FloatTensor] = None,
  296. output_attentions: Optional[bool] = False,
  297. **kwargs,
  298. ) -> tuple[torch.Tensor]:
  299. input_shape = hidden_states.shape[:-1]
  300. hidden_shape = (*input_shape, -1, self.attention_head_size)
  301. query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  302. key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  303. value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  304. attention_interface: Callable = eager_attention_forward
  305. if self.config._attn_implementation != "eager":
  306. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  307. attn_output, attn_weights = attention_interface(
  308. self,
  309. query_states,
  310. key_states,
  311. value_states,
  312. attention_mask,
  313. dropout=0.0 if not self.training else self.attention_dropout,
  314. scaling=self.scaling,
  315. head_mask=head_mask,
  316. **kwargs,
  317. )
  318. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  319. outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
  320. return outputs
  321. # Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->MarkupLM
  322. class MarkupLMAttention(nn.Module):
  323. def __init__(self, config):
  324. super().__init__()
  325. self.self = MarkupLMSelfAttention(config)
  326. self.output = MarkupLMSelfOutput(config)
  327. self.pruned_heads = set()
  328. def prune_heads(self, heads):
  329. if len(heads) == 0:
  330. return
  331. heads, index = find_pruneable_heads_and_indices(
  332. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  333. )
  334. # Prune linear layers
  335. self.self.query = prune_linear_layer(self.self.query, index)
  336. self.self.key = prune_linear_layer(self.self.key, index)
  337. self.self.value = prune_linear_layer(self.self.value, index)
  338. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  339. # Update hyper params and store pruned heads
  340. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  341. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  342. self.pruned_heads = self.pruned_heads.union(heads)
  343. def forward(
  344. self,
  345. hidden_states: torch.Tensor,
  346. attention_mask: Optional[torch.FloatTensor] = None,
  347. head_mask: Optional[torch.FloatTensor] = None,
  348. output_attentions: Optional[bool] = False,
  349. **kwargs,
  350. ) -> tuple[torch.Tensor]:
  351. self_outputs = self.self(
  352. hidden_states,
  353. attention_mask=attention_mask,
  354. head_mask=head_mask,
  355. output_attentions=output_attentions,
  356. **kwargs,
  357. )
  358. attention_output = self.output(self_outputs[0], hidden_states)
  359. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  360. return outputs
  361. # Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->MarkupLM
  362. class MarkupLMLayer(GradientCheckpointingLayer):
  363. def __init__(self, config):
  364. super().__init__()
  365. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  366. self.seq_len_dim = 1
  367. self.attention = MarkupLMAttention(config)
  368. self.intermediate = MarkupLMIntermediate(config)
  369. self.output = MarkupLMOutput(config)
  370. def forward(
  371. self,
  372. hidden_states: torch.Tensor,
  373. attention_mask: Optional[torch.FloatTensor] = None,
  374. head_mask: Optional[torch.FloatTensor] = None,
  375. output_attentions: Optional[bool] = False,
  376. **kwargs,
  377. ) -> tuple[torch.Tensor]:
  378. self_attention_outputs = self.attention(
  379. hidden_states,
  380. attention_mask=attention_mask,
  381. head_mask=head_mask,
  382. output_attentions=output_attentions,
  383. **kwargs,
  384. )
  385. attention_output = self_attention_outputs[0]
  386. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  387. layer_output = apply_chunking_to_forward(
  388. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  389. )
  390. outputs = (layer_output,) + outputs
  391. return outputs
  392. def feed_forward_chunk(self, attention_output):
  393. intermediate_output = self.intermediate(attention_output)
  394. layer_output = self.output(intermediate_output, attention_output)
  395. return layer_output
  396. # Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->MarkupLM
  397. class MarkupLMEncoder(nn.Module):
  398. def __init__(self, config):
  399. super().__init__()
  400. self.config = config
  401. self.layer = nn.ModuleList([MarkupLMLayer(config) for i in range(config.num_hidden_layers)])
  402. self.gradient_checkpointing = False
  403. @can_return_tuple
  404. def forward(
  405. self,
  406. hidden_states: torch.Tensor,
  407. attention_mask: Optional[torch.FloatTensor] = None,
  408. head_mask: Optional[torch.FloatTensor] = None,
  409. output_attentions: Optional[bool] = False,
  410. output_hidden_states: Optional[bool] = False,
  411. return_dict: Optional[bool] = True,
  412. **kwargs,
  413. ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
  414. all_hidden_states = () if output_hidden_states else None
  415. all_self_attentions = () if output_attentions else None
  416. for i, layer_module in enumerate(self.layer):
  417. if output_hidden_states:
  418. all_hidden_states = all_hidden_states + (hidden_states,)
  419. layer_head_mask = head_mask[i] if head_mask is not None else None
  420. layer_outputs = layer_module(
  421. hidden_states=hidden_states,
  422. attention_mask=attention_mask,
  423. head_mask=layer_head_mask,
  424. output_attentions=output_attentions,
  425. **kwargs,
  426. )
  427. hidden_states = layer_outputs[0]
  428. if output_attentions:
  429. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  430. if output_hidden_states:
  431. all_hidden_states = all_hidden_states + (hidden_states,)
  432. return BaseModelOutput(
  433. last_hidden_state=hidden_states,
  434. hidden_states=all_hidden_states,
  435. attentions=all_self_attentions,
  436. )
  437. @auto_docstring
  438. class MarkupLMPreTrainedModel(PreTrainedModel):
  439. config: MarkupLMConfig
  440. base_model_prefix = "markuplm"
  441. # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with Bert->MarkupLM
  442. def _init_weights(self, module):
  443. """Initialize the weights"""
  444. if isinstance(module, nn.Linear):
  445. # Slightly different from the TF version which uses truncated_normal for initialization
  446. # cf https://github.com/pytorch/pytorch/pull/5617
  447. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  448. if module.bias is not None:
  449. module.bias.data.zero_()
  450. elif isinstance(module, nn.Embedding):
  451. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  452. if module.padding_idx is not None:
  453. module.weight.data[module.padding_idx].zero_()
  454. elif isinstance(module, nn.LayerNorm):
  455. module.bias.data.zero_()
  456. module.weight.data.fill_(1.0)
  457. elif isinstance(module, MarkupLMLMPredictionHead):
  458. module.bias.data.zero_()
  459. @classmethod
  460. def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
  461. return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
  462. @auto_docstring
  463. class MarkupLMModel(MarkupLMPreTrainedModel):
  464. # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->MarkupLM
  465. def __init__(self, config, add_pooling_layer=True):
  466. r"""
  467. add_pooling_layer (bool, *optional*, defaults to `True`):
  468. Whether to add a pooling layer
  469. """
  470. super().__init__(config)
  471. self.config = config
  472. self.embeddings = MarkupLMEmbeddings(config)
  473. self.encoder = MarkupLMEncoder(config)
  474. self.pooler = MarkupLMPooler(config) if add_pooling_layer else None
  475. # Initialize weights and apply final processing
  476. self.post_init()
  477. def get_input_embeddings(self):
  478. return self.embeddings.word_embeddings
  479. def set_input_embeddings(self, value):
  480. self.embeddings.word_embeddings = value
  481. def _prune_heads(self, heads_to_prune):
  482. """
  483. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  484. class PreTrainedModel
  485. """
  486. for layer, heads in heads_to_prune.items():
  487. self.encoder.layer[layer].attention.prune_heads(heads)
  488. @can_return_tuple
  489. @auto_docstring
  490. def forward(
  491. self,
  492. input_ids: Optional[torch.LongTensor] = None,
  493. xpath_tags_seq: Optional[torch.LongTensor] = None,
  494. xpath_subs_seq: Optional[torch.LongTensor] = None,
  495. attention_mask: Optional[torch.FloatTensor] = None,
  496. token_type_ids: Optional[torch.LongTensor] = None,
  497. position_ids: Optional[torch.LongTensor] = None,
  498. head_mask: Optional[torch.FloatTensor] = None,
  499. inputs_embeds: Optional[torch.FloatTensor] = None,
  500. output_attentions: Optional[bool] = None,
  501. output_hidden_states: Optional[bool] = None,
  502. return_dict: Optional[bool] = None,
  503. ) -> Union[tuple, BaseModelOutputWithPooling]:
  504. r"""
  505. xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
  506. Tag IDs for each token in the input sequence, padded up to config.max_depth.
  507. xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
  508. Subscript IDs for each token in the input sequence, padded up to config.max_depth.
  509. Examples:
  510. ```python
  511. >>> from transformers import AutoProcessor, MarkupLMModel
  512. >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
  513. >>> model = MarkupLMModel.from_pretrained("microsoft/markuplm-base")
  514. >>> html_string = "<html> <head> <title>Page Title</title> </head> </html>"
  515. >>> encoding = processor(html_string, return_tensors="pt")
  516. >>> outputs = model(**encoding)
  517. >>> last_hidden_states = outputs.last_hidden_state
  518. >>> list(last_hidden_states.shape)
  519. [1, 4, 768]
  520. ```"""
  521. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  522. output_hidden_states = (
  523. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  524. )
  525. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  526. if input_ids is not None and inputs_embeds is not None:
  527. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  528. elif input_ids is not None:
  529. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  530. input_shape = input_ids.size()
  531. elif inputs_embeds is not None:
  532. input_shape = inputs_embeds.size()[:-1]
  533. else:
  534. raise ValueError("You have to specify either input_ids or inputs_embeds")
  535. device = input_ids.device if input_ids is not None else inputs_embeds.device
  536. if attention_mask is None:
  537. attention_mask = torch.ones(input_shape, device=device)
  538. if token_type_ids is None:
  539. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  540. extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  541. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
  542. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  543. if head_mask is not None:
  544. if head_mask.dim() == 1:
  545. head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
  546. head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
  547. elif head_mask.dim() == 2:
  548. head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
  549. head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
  550. else:
  551. head_mask = [None] * self.config.num_hidden_layers
  552. embedding_output = self.embeddings(
  553. input_ids=input_ids,
  554. xpath_tags_seq=xpath_tags_seq,
  555. xpath_subs_seq=xpath_subs_seq,
  556. position_ids=position_ids,
  557. token_type_ids=token_type_ids,
  558. inputs_embeds=inputs_embeds,
  559. )
  560. encoder_outputs = self.encoder(
  561. embedding_output,
  562. extended_attention_mask,
  563. head_mask=head_mask,
  564. output_attentions=output_attentions,
  565. output_hidden_states=output_hidden_states,
  566. return_dict=True,
  567. )
  568. sequence_output = encoder_outputs[0]
  569. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  570. return BaseModelOutputWithPooling(
  571. last_hidden_state=sequence_output,
  572. pooler_output=pooled_output,
  573. hidden_states=encoder_outputs.hidden_states,
  574. attentions=encoder_outputs.attentions,
  575. )
  576. @auto_docstring
  577. class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
  578. # Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering.__init__ with bert->markuplm, Bert->MarkupLM
  579. def __init__(self, config):
  580. super().__init__(config)
  581. self.num_labels = config.num_labels
  582. self.markuplm = MarkupLMModel(config, add_pooling_layer=False)
  583. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  584. # Initialize weights and apply final processing
  585. self.post_init()
  586. @can_return_tuple
  587. @auto_docstring
  588. def forward(
  589. self,
  590. input_ids: Optional[torch.Tensor] = None,
  591. xpath_tags_seq: Optional[torch.Tensor] = None,
  592. xpath_subs_seq: Optional[torch.Tensor] = None,
  593. attention_mask: Optional[torch.Tensor] = None,
  594. token_type_ids: Optional[torch.Tensor] = None,
  595. position_ids: Optional[torch.Tensor] = None,
  596. head_mask: Optional[torch.Tensor] = None,
  597. inputs_embeds: Optional[torch.Tensor] = None,
  598. start_positions: Optional[torch.Tensor] = None,
  599. end_positions: Optional[torch.Tensor] = None,
  600. output_attentions: Optional[bool] = None,
  601. output_hidden_states: Optional[bool] = None,
  602. return_dict: Optional[bool] = None,
  603. ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
  604. r"""
  605. xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
  606. Tag IDs for each token in the input sequence, padded up to config.max_depth.
  607. xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
  608. Subscript IDs for each token in the input sequence, padded up to config.max_depth.
  609. Examples:
  610. ```python
  611. >>> from transformers import AutoProcessor, MarkupLMForQuestionAnswering
  612. >>> import torch
  613. >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base-finetuned-websrc")
  614. >>> model = MarkupLMForQuestionAnswering.from_pretrained("microsoft/markuplm-base-finetuned-websrc")
  615. >>> html_string = "<html> <head> <title>My name is Niels</title> </head> </html>"
  616. >>> question = "What's his name?"
  617. >>> encoding = processor(html_string, questions=question, return_tensors="pt")
  618. >>> with torch.no_grad():
  619. ... outputs = model(**encoding)
  620. >>> answer_start_index = outputs.start_logits.argmax()
  621. >>> answer_end_index = outputs.end_logits.argmax()
  622. >>> predict_answer_tokens = encoding.input_ids[0, answer_start_index : answer_end_index + 1]
  623. >>> processor.decode(predict_answer_tokens).strip()
  624. 'Niels'
  625. ```"""
  626. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  627. outputs = self.markuplm(
  628. input_ids,
  629. xpath_tags_seq=xpath_tags_seq,
  630. xpath_subs_seq=xpath_subs_seq,
  631. attention_mask=attention_mask,
  632. token_type_ids=token_type_ids,
  633. position_ids=position_ids,
  634. head_mask=head_mask,
  635. inputs_embeds=inputs_embeds,
  636. output_attentions=output_attentions,
  637. output_hidden_states=output_hidden_states,
  638. return_dict=True,
  639. )
  640. sequence_output = outputs[0]
  641. logits = self.qa_outputs(sequence_output)
  642. start_logits, end_logits = logits.split(1, dim=-1)
  643. start_logits = start_logits.squeeze(-1).contiguous()
  644. end_logits = end_logits.squeeze(-1).contiguous()
  645. total_loss = None
  646. if start_positions is not None and end_positions is not None:
  647. # If we are on multi-GPU, split add a dimension
  648. if len(start_positions.size()) > 1:
  649. start_positions = start_positions.squeeze(-1)
  650. if len(end_positions.size()) > 1:
  651. end_positions = end_positions.squeeze(-1)
  652. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  653. ignored_index = start_logits.size(1)
  654. start_positions.clamp_(0, ignored_index)
  655. end_positions.clamp_(0, ignored_index)
  656. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  657. start_loss = loss_fct(start_logits, start_positions)
  658. end_loss = loss_fct(end_logits, end_positions)
  659. total_loss = (start_loss + end_loss) / 2
  660. return QuestionAnsweringModelOutput(
  661. loss=total_loss,
  662. start_logits=start_logits,
  663. end_logits=end_logits,
  664. hidden_states=outputs.hidden_states,
  665. attentions=outputs.attentions,
  666. )
  667. @auto_docstring(
  668. custom_intro="""
  669. MarkupLM Model with a `token_classification` head on top.
  670. """
  671. )
  672. class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
  673. # Copied from transformers.models.bert.modeling_bert.BertForTokenClassification.__init__ with bert->markuplm, Bert->MarkupLM
  674. def __init__(self, config):
  675. super().__init__(config)
  676. self.num_labels = config.num_labels
  677. self.markuplm = MarkupLMModel(config, add_pooling_layer=False)
  678. classifier_dropout = (
  679. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  680. )
  681. self.dropout = nn.Dropout(classifier_dropout)
  682. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  683. # Initialize weights and apply final processing
  684. self.post_init()
  685. @can_return_tuple
  686. @auto_docstring
  687. def forward(
  688. self,
  689. input_ids: Optional[torch.Tensor] = None,
  690. xpath_tags_seq: Optional[torch.Tensor] = None,
  691. xpath_subs_seq: Optional[torch.Tensor] = None,
  692. attention_mask: Optional[torch.Tensor] = None,
  693. token_type_ids: Optional[torch.Tensor] = None,
  694. position_ids: Optional[torch.Tensor] = None,
  695. head_mask: Optional[torch.Tensor] = None,
  696. inputs_embeds: Optional[torch.Tensor] = None,
  697. labels: Optional[torch.Tensor] = None,
  698. output_attentions: Optional[bool] = None,
  699. output_hidden_states: Optional[bool] = None,
  700. return_dict: Optional[bool] = None,
  701. ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
  702. r"""
  703. xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
  704. Tag IDs for each token in the input sequence, padded up to config.max_depth.
  705. xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
  706. Subscript IDs for each token in the input sequence, padded up to config.max_depth.
  707. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  708. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  709. Examples:
  710. ```python
  711. >>> from transformers import AutoProcessor, AutoModelForTokenClassification
  712. >>> import torch
  713. >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
  714. >>> processor.parse_html = False
  715. >>> model = AutoModelForTokenClassification.from_pretrained("microsoft/markuplm-base", num_labels=7)
  716. >>> nodes = ["hello", "world"]
  717. >>> xpaths = ["/html/body/div/li[1]/div/span", "/html/body/div/li[1]/div/span"]
  718. >>> node_labels = [1, 2]
  719. >>> encoding = processor(nodes=nodes, xpaths=xpaths, node_labels=node_labels, return_tensors="pt")
  720. >>> with torch.no_grad():
  721. ... outputs = model(**encoding)
  722. >>> loss = outputs.loss
  723. >>> logits = outputs.logits
  724. ```"""
  725. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  726. outputs = self.markuplm(
  727. input_ids,
  728. xpath_tags_seq=xpath_tags_seq,
  729. xpath_subs_seq=xpath_subs_seq,
  730. attention_mask=attention_mask,
  731. token_type_ids=token_type_ids,
  732. position_ids=position_ids,
  733. head_mask=head_mask,
  734. inputs_embeds=inputs_embeds,
  735. output_attentions=output_attentions,
  736. output_hidden_states=output_hidden_states,
  737. return_dict=True,
  738. )
  739. sequence_output = outputs[0]
  740. prediction_scores = self.classifier(sequence_output) # (batch_size, seq_length, node_type_size)
  741. loss = None
  742. if labels is not None:
  743. loss_fct = CrossEntropyLoss()
  744. loss = loss_fct(
  745. prediction_scores.view(-1, self.config.num_labels),
  746. labels.view(-1),
  747. )
  748. return TokenClassifierOutput(
  749. loss=loss,
  750. logits=prediction_scores,
  751. hidden_states=outputs.hidden_states,
  752. attentions=outputs.attentions,
  753. )
  754. @auto_docstring(
  755. custom_intro="""
  756. MarkupLM Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  757. pooled output) e.g. for GLUE tasks.
  758. """
  759. )
  760. class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
  761. # Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification.__init__ with bert->markuplm, Bert->MarkupLM
  762. def __init__(self, config):
  763. super().__init__(config)
  764. self.num_labels = config.num_labels
  765. self.config = config
  766. self.markuplm = MarkupLMModel(config)
  767. classifier_dropout = (
  768. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  769. )
  770. self.dropout = nn.Dropout(classifier_dropout)
  771. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  772. # Initialize weights and apply final processing
  773. self.post_init()
  774. @can_return_tuple
  775. @auto_docstring
  776. def forward(
  777. self,
  778. input_ids: Optional[torch.Tensor] = None,
  779. xpath_tags_seq: Optional[torch.Tensor] = None,
  780. xpath_subs_seq: Optional[torch.Tensor] = None,
  781. attention_mask: Optional[torch.Tensor] = None,
  782. token_type_ids: Optional[torch.Tensor] = None,
  783. position_ids: Optional[torch.Tensor] = None,
  784. head_mask: Optional[torch.Tensor] = None,
  785. inputs_embeds: Optional[torch.Tensor] = None,
  786. labels: Optional[torch.Tensor] = None,
  787. output_attentions: Optional[bool] = None,
  788. output_hidden_states: Optional[bool] = None,
  789. return_dict: Optional[bool] = None,
  790. ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
  791. r"""
  792. xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
  793. Tag IDs for each token in the input sequence, padded up to config.max_depth.
  794. xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
  795. Subscript IDs for each token in the input sequence, padded up to config.max_depth.
  796. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  797. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  798. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  799. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  800. Examples:
  801. ```python
  802. >>> from transformers import AutoProcessor, AutoModelForSequenceClassification
  803. >>> import torch
  804. >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
  805. >>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/markuplm-base", num_labels=7)
  806. >>> html_string = "<html> <head> <title>Page Title</title> </head> </html>"
  807. >>> encoding = processor(html_string, return_tensors="pt")
  808. >>> with torch.no_grad():
  809. ... outputs = model(**encoding)
  810. >>> loss = outputs.loss
  811. >>> logits = outputs.logits
  812. ```"""
  813. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  814. outputs = self.markuplm(
  815. input_ids,
  816. xpath_tags_seq=xpath_tags_seq,
  817. xpath_subs_seq=xpath_subs_seq,
  818. attention_mask=attention_mask,
  819. token_type_ids=token_type_ids,
  820. position_ids=position_ids,
  821. head_mask=head_mask,
  822. inputs_embeds=inputs_embeds,
  823. output_attentions=output_attentions,
  824. output_hidden_states=output_hidden_states,
  825. return_dict=True,
  826. )
  827. pooled_output = outputs[1]
  828. pooled_output = self.dropout(pooled_output)
  829. logits = self.classifier(pooled_output)
  830. loss = None
  831. if labels is not None:
  832. if self.config.problem_type is None:
  833. if self.num_labels == 1:
  834. self.config.problem_type = "regression"
  835. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  836. self.config.problem_type = "single_label_classification"
  837. else:
  838. self.config.problem_type = "multi_label_classification"
  839. if self.config.problem_type == "regression":
  840. loss_fct = MSELoss()
  841. if self.num_labels == 1:
  842. loss = loss_fct(logits.squeeze(), labels.squeeze())
  843. else:
  844. loss = loss_fct(logits, labels)
  845. elif self.config.problem_type == "single_label_classification":
  846. loss_fct = CrossEntropyLoss()
  847. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  848. elif self.config.problem_type == "multi_label_classification":
  849. loss_fct = BCEWithLogitsLoss()
  850. loss = loss_fct(logits, labels)
  851. return SequenceClassifierOutput(
  852. loss=loss,
  853. logits=logits,
  854. hidden_states=outputs.hidden_states,
  855. attentions=outputs.attentions,
  856. )
# Public API of this module: the base model, the pretrained-model base class and the
# task-specific heads. Kept as a literal list so tools that read it statically still work.
__all__ = [
    "MarkupLMForQuestionAnswering",
    "MarkupLMForSequenceClassification",
    "MarkupLMForTokenClassification",
    "MarkupLMModel",
    "MarkupLMPreTrainedModel",
]