modeling_bros.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136
  1. # coding=utf-8
  2. # Copyright 2023-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Bros model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Optional, Union
  19. import torch
  20. from torch import nn
  21. from torch.nn import CrossEntropyLoss
  22. from ...activations import ACT2FN
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import (
  25. BaseModelOutputWithCrossAttentions,
  26. BaseModelOutputWithPoolingAndCrossAttentions,
  27. TokenClassifierOutput,
  28. )
  29. from ...modeling_utils import PreTrainedModel
  30. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  31. from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
  32. from .configuration_bros import BrosConfig
  33. logger = logging.get_logger(__name__)
# Output container returned by the SPADE-style token classification head.
# The class docstring below is consumed by `auto_docstring`, so its text is left untouched.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of token classification models.
    """
)
class BrosSpadeOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification loss.
    initial_token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
        Classification scores for entity initial tokens (before SoftMax).
    subsequent_token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length+1)`):
        Classification scores for entity sequence tokens (before SoftMax).
    """

    # Every field defaults to None, so an instance may be created with any subset populated.
    loss: Optional[torch.FloatTensor] = None
    initial_token_logits: Optional[torch.FloatTensor] = None
    subsequent_token_logits: Optional[torch.FloatTensor] = None
    # NOTE(review): the two fields below are not described in the docstring above; presumably they
    # carry the per-layer hidden states / attention maps of the backbone — confirm against the head
    # that builds this output.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
  54. class BrosPositionalEmbedding1D(nn.Module):
  55. # Reference: https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L15
  56. def __init__(self, config):
  57. super().__init__()
  58. self.dim_bbox_sinusoid_emb_1d = config.dim_bbox_sinusoid_emb_1d
  59. inv_freq = 1 / (
  60. 10000 ** (torch.arange(0.0, self.dim_bbox_sinusoid_emb_1d, 2.0) / self.dim_bbox_sinusoid_emb_1d)
  61. )
  62. self.register_buffer("inv_freq", inv_freq)
  63. def forward(self, pos_seq: torch.Tensor) -> torch.Tensor:
  64. seq_size = pos_seq.size()
  65. b1, b2, b3 = seq_size
  66. sinusoid_inp = pos_seq.view(b1, b2, b3, 1) * self.inv_freq.view(1, 1, 1, self.dim_bbox_sinusoid_emb_1d // 2)
  67. pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
  68. return pos_emb
  69. class BrosPositionalEmbedding2D(nn.Module):
  70. def __init__(self, config):
  71. super().__init__()
  72. self.dim_bbox = config.dim_bbox
  73. self.x_pos_emb = BrosPositionalEmbedding1D(config)
  74. self.y_pos_emb = BrosPositionalEmbedding1D(config)
  75. def forward(self, bbox: torch.Tensor) -> torch.Tensor:
  76. stack = []
  77. for i in range(self.dim_bbox):
  78. if i % 2 == 0:
  79. stack.append(self.x_pos_emb(bbox[..., i]))
  80. else:
  81. stack.append(self.y_pos_emb(bbox[..., i]))
  82. bbox_pos_emb = torch.cat(stack, dim=-1)
  83. return bbox_pos_emb
  84. class BrosBboxEmbeddings(nn.Module):
  85. def __init__(self, config):
  86. super().__init__()
  87. self.bbox_sinusoid_emb = BrosPositionalEmbedding2D(config)
  88. self.bbox_projection = nn.Linear(config.dim_bbox_sinusoid_emb_2d, config.dim_bbox_projection, bias=False)
  89. def forward(self, bbox: torch.Tensor):
  90. bbox_t = bbox.transpose(0, 1)
  91. bbox_pos = bbox_t[None, :, :, :] - bbox_t[:, None, :, :]
  92. bbox_pos_emb = self.bbox_sinusoid_emb(bbox_pos)
  93. bbox_pos_emb = self.bbox_projection(bbox_pos_emb)
  94. return bbox_pos_emb
  95. class BrosTextEmbeddings(nn.Module):
  96. """Construct the embeddings from word, position and token_type embeddings."""
  97. def __init__(self, config):
  98. super().__init__()
  99. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  100. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  101. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  102. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  103. # any TensorFlow checkpoint file
  104. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  105. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  106. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  107. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  108. self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
  109. self.register_buffer(
  110. "token_type_ids",
  111. torch.zeros(
  112. self.position_ids.size(),
  113. dtype=torch.long,
  114. device=self.position_ids.device,
  115. ),
  116. persistent=False,
  117. )
  118. def forward(
  119. self,
  120. input_ids: Optional[torch.Tensor] = None,
  121. token_type_ids: Optional[torch.Tensor] = None,
  122. position_ids: Optional[torch.Tensor] = None,
  123. inputs_embeds: Optional[torch.Tensor] = None,
  124. ) -> torch.Tensor:
  125. if input_ids is not None:
  126. input_shape = input_ids.size()
  127. else:
  128. input_shape = inputs_embeds.size()[:-1]
  129. seq_length = input_shape[1]
  130. if position_ids is None:
  131. position_ids = self.position_ids[:, :seq_length]
  132. if token_type_ids is None:
  133. if hasattr(self, "token_type_ids"):
  134. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  135. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  136. token_type_ids = buffered_token_type_ids_expanded
  137. else:
  138. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  139. if inputs_embeds is None:
  140. inputs_embeds = self.word_embeddings(input_ids)
  141. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  142. embeddings = inputs_embeds + token_type_embeddings
  143. if self.position_embedding_type == "absolute":
  144. position_embeddings = self.position_embeddings(position_ids)
  145. embeddings += position_embeddings
  146. embeddings = self.LayerNorm(embeddings)
  147. embeddings = self.dropout(embeddings)
  148. return embeddings
  149. class BrosSelfAttention(nn.Module):
  150. def __init__(self, config):
  151. super().__init__()
  152. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  153. raise ValueError(
  154. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  155. f"heads ({config.num_attention_heads})"
  156. )
  157. self.num_attention_heads = config.num_attention_heads
  158. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  159. self.all_head_size = self.num_attention_heads * self.attention_head_size
  160. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  161. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  162. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  163. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  164. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  165. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  166. self.max_position_embeddings = config.max_position_embeddings
  167. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  168. self.is_decoder = config.is_decoder
  169. def forward(
  170. self,
  171. hidden_states: torch.Tensor,
  172. bbox_pos_emb: torch.Tensor,
  173. attention_mask: Optional[torch.Tensor] = None,
  174. head_mask: Optional[torch.Tensor] = None,
  175. encoder_hidden_states: Optional[torch.Tensor] = None,
  176. encoder_attention_mask: Optional[torch.Tensor] = None,
  177. output_attentions: Optional[torch.Tensor] = False,
  178. ) -> tuple[torch.Tensor]:
  179. hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)
  180. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  181. # If this is instantiated as a cross-attention module, the keys
  182. # and values come from an encoder; the attention mask needs to be
  183. # such that the encoder's padding tokens are not attended to.
  184. is_cross_attention = encoder_hidden_states is not None
  185. if is_cross_attention:
  186. key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
  187. value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
  188. attention_mask = encoder_attention_mask
  189. else:
  190. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  191. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  192. # Take the dot product between "query" and "key" to get the raw attention scores.
  193. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  194. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  195. seq_length = hidden_states.size()[1]
  196. position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
  197. position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
  198. distance = position_ids_l - position_ids_r
  199. positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
  200. positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
  201. if self.position_embedding_type == "relative_key":
  202. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  203. attention_scores = attention_scores + relative_position_scores
  204. elif self.position_embedding_type == "relative_key_query":
  205. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  206. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  207. attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  208. # bbox positional encoding
  209. batch_size, n_head, seq_length, d_head = query_layer.shape
  210. bbox_pos_emb = bbox_pos_emb.view(seq_length, seq_length, batch_size, d_head)
  211. bbox_pos_emb = bbox_pos_emb.permute([2, 0, 1, 3])
  212. bbox_pos_scores = torch.einsum("bnid,bijd->bnij", (query_layer, bbox_pos_emb))
  213. attention_scores = attention_scores + bbox_pos_scores
  214. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  215. if attention_mask is not None:
  216. # Apply the attention mask is (precomputed for all layers in BrosModel forward() function)
  217. attention_scores = attention_scores + attention_mask
  218. # Normalize the attention scores to probabilities.
  219. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  220. # This is actually dropping out entire tokens to attend to, which might
  221. # seem a bit unusual, but is taken from the original Transformer paper.
  222. attention_probs = self.dropout(attention_probs)
  223. # Mask heads if we want to
  224. if head_mask is not None:
  225. attention_probs = attention_probs * head_mask
  226. context_layer = torch.matmul(attention_probs, value_layer)
  227. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  228. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  229. context_layer = context_layer.view(*new_context_layer_shape)
  230. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  231. if self.is_decoder:
  232. outputs = outputs + (None,)
  233. return outputs
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Bros
class BrosSelfOutput(nn.Module):
    """Post-attention block: dense projection, dropout, then residual add followed by LayerNorm."""

    def __init__(self, config):
        super().__init__()
        # hidden_size -> hidden_size projection of the attention context.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Residual connection: `input_tensor` is added before normalization.
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
  246. class BrosAttention(nn.Module):
  247. def __init__(self, config):
  248. super().__init__()
  249. self.self = BrosSelfAttention(config)
  250. self.output = BrosSelfOutput(config)
  251. self.pruned_heads = set()
  252. def prune_heads(self, heads):
  253. if len(heads) == 0:
  254. return
  255. heads, index = find_pruneable_heads_and_indices(
  256. heads,
  257. self.self.num_attention_heads,
  258. self.self.attention_head_size,
  259. self.pruned_heads,
  260. )
  261. # Prune linear layers
  262. self.self.query = prune_linear_layer(self.self.query, index)
  263. self.self.key = prune_linear_layer(self.self.key, index)
  264. self.self.value = prune_linear_layer(self.self.value, index)
  265. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  266. # Update hyper params and store pruned heads
  267. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  268. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  269. self.pruned_heads = self.pruned_heads.union(heads)
  270. def forward(
  271. self,
  272. hidden_states: torch.Tensor,
  273. bbox_pos_emb: torch.Tensor,
  274. attention_mask: Optional[torch.Tensor] = None,
  275. head_mask: Optional[torch.Tensor] = None,
  276. encoder_hidden_states: Optional[torch.Tensor] = None,
  277. encoder_attention_mask: Optional[torch.Tensor] = None,
  278. output_attentions: Optional[bool] = False,
  279. ) -> tuple[torch.Tensor]:
  280. self_outputs = self.self(
  281. hidden_states=hidden_states,
  282. bbox_pos_emb=bbox_pos_emb,
  283. attention_mask=attention_mask,
  284. head_mask=head_mask,
  285. encoder_hidden_states=encoder_hidden_states,
  286. encoder_attention_mask=encoder_attention_mask,
  287. output_attentions=output_attentions,
  288. )
  289. attention_output = self.output(self_outputs[0], hidden_states)
  290. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  291. return outputs
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Bros
class BrosIntermediate(nn.Module):
    """First half of the feed-forward block: expand to `intermediate_size` and apply the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a key into ACT2FN or an object that is used as the activation directly.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``intermediate_act_fn(dense(hidden_states))``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
  305. class BrosOutput(nn.Module):
  306. def __init__(self, config):
  307. super().__init__()
  308. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  309. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  310. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  311. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  312. hidden_states = self.dense(hidden_states)
  313. hidden_states = self.dropout(hidden_states)
  314. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  315. return hidden_states
  316. class BrosLayer(GradientCheckpointingLayer):
  317. def __init__(self, config):
  318. super().__init__()
  319. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  320. self.seq_len_dim = 1
  321. self.attention = BrosAttention(config)
  322. self.is_decoder = config.is_decoder
  323. self.add_cross_attention = config.add_cross_attention
  324. if self.add_cross_attention:
  325. if not self.is_decoder:
  326. raise Exception(f"{self} should be used as a decoder model if cross attention is added")
  327. self.crossattention = BrosAttention(config)
  328. self.intermediate = BrosIntermediate(config)
  329. self.output = BrosOutput(config)
  330. def forward(
  331. self,
  332. hidden_states: torch.Tensor,
  333. bbox_pos_emb: torch.Tensor,
  334. attention_mask: Optional[torch.FloatTensor] = None,
  335. head_mask: Optional[torch.FloatTensor] = None,
  336. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  337. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  338. output_attentions: Optional[bool] = False,
  339. ) -> tuple[torch.Tensor]:
  340. self_attention_outputs = self.attention(
  341. hidden_states,
  342. bbox_pos_emb=bbox_pos_emb,
  343. attention_mask=attention_mask,
  344. head_mask=head_mask,
  345. output_attentions=output_attentions,
  346. )
  347. attention_output = self_attention_outputs[0]
  348. # if decoder, the last output is tuple of self-attn cache
  349. if self.is_decoder:
  350. outputs = self_attention_outputs[1:-1]
  351. else:
  352. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  353. if self.is_decoder and encoder_hidden_states is not None:
  354. if hasattr(self, "crossattention"):
  355. raise Exception(
  356. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
  357. )
  358. cross_attention_outputs = self.crossattention(
  359. attention_output,
  360. attention_mask=attention_mask,
  361. head_mask=head_mask,
  362. encoder_hidden_states=encoder_hidden_states,
  363. encoder_attention_mask=encoder_attention_mask,
  364. output_attentions=output_attentions,
  365. )
  366. attention_output = cross_attention_outputs[0]
  367. outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
  368. layer_output = apply_chunking_to_forward(
  369. self.feed_forward_chunk,
  370. self.chunk_size_feed_forward,
  371. self.seq_len_dim,
  372. attention_output,
  373. )
  374. outputs = (layer_output,) + outputs
  375. # if decoder, return the attn key/values as the last output
  376. if self.is_decoder:
  377. outputs = outputs + (None,)
  378. return outputs
  379. def feed_forward_chunk(self, attention_output):
  380. intermediate_output = self.intermediate(attention_output)
  381. layer_output = self.output(intermediate_output, attention_output)
  382. return layer_output
  383. class BrosEncoder(nn.Module):
  384. def __init__(self, config):
  385. super().__init__()
  386. self.config = config
  387. self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)])
  388. @can_return_tuple
  389. def forward(
  390. self,
  391. hidden_states: torch.Tensor,
  392. bbox_pos_emb: torch.Tensor,
  393. attention_mask: Optional[torch.FloatTensor] = None,
  394. head_mask: Optional[torch.FloatTensor] = None,
  395. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  396. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  397. output_attentions: Optional[bool] = False,
  398. output_hidden_states: Optional[bool] = False,
  399. return_dict: Optional[bool] = True,
  400. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]:
  401. all_hidden_states = () if output_hidden_states else None
  402. all_self_attentions = () if output_attentions else None
  403. all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
  404. for i, layer_module in enumerate(self.layer):
  405. if output_hidden_states:
  406. all_hidden_states = all_hidden_states + (hidden_states,)
  407. layer_head_mask = head_mask[i] if head_mask is not None else None
  408. layer_outputs = layer_module(
  409. hidden_states=hidden_states,
  410. bbox_pos_emb=bbox_pos_emb,
  411. attention_mask=attention_mask,
  412. head_mask=layer_head_mask,
  413. encoder_hidden_states=encoder_hidden_states,
  414. encoder_attention_mask=encoder_attention_mask,
  415. output_attentions=output_attentions,
  416. )
  417. hidden_states = layer_outputs[0]
  418. if output_attentions:
  419. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  420. if self.config.add_cross_attention:
  421. all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
  422. if output_hidden_states:
  423. all_hidden_states = all_hidden_states + (hidden_states,)
  424. return BaseModelOutputWithCrossAttentions(
  425. last_hidden_state=hidden_states,
  426. hidden_states=all_hidden_states,
  427. attentions=all_self_attentions,
  428. cross_attentions=all_cross_attentions,
  429. )
# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->Bros
class BrosPooler(nn.Module):
    """Pool a sequence by passing the first position's hidden state through a dense layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(dense(hidden_states[:, 0]))``."""
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
  443. class BrosRelationExtractor(nn.Module):
  444. def __init__(self, config):
  445. super().__init__()
  446. self.n_relations = config.n_relations
  447. self.backbone_hidden_size = config.hidden_size
  448. self.head_hidden_size = config.hidden_size
  449. self.classifier_dropout_prob = config.classifier_dropout_prob
  450. self.drop = nn.Dropout(self.classifier_dropout_prob)
  451. self.query = nn.Linear(self.backbone_hidden_size, self.n_relations * self.head_hidden_size)
  452. self.key = nn.Linear(self.backbone_hidden_size, self.n_relations * self.head_hidden_size)
  453. self.dummy_node = nn.Parameter(torch.zeros(1, self.backbone_hidden_size))
  454. def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):
  455. query_layer = self.query(self.drop(query_layer))
  456. dummy_vec = self.dummy_node.unsqueeze(0).repeat(1, key_layer.size(1), 1)
  457. key_layer = torch.cat([key_layer, dummy_vec], axis=0)
  458. key_layer = self.key(self.drop(key_layer))
  459. query_layer = query_layer.view(
  460. query_layer.size(0), query_layer.size(1), self.n_relations, self.head_hidden_size
  461. )
  462. key_layer = key_layer.view(key_layer.size(0), key_layer.size(1), self.n_relations, self.head_hidden_size)
  463. relation_score = torch.matmul(
  464. query_layer.permute(2, 1, 0, 3), key_layer.permute(2, 1, 3, 0)
  465. ) # equivalent to torch.einsum("ibnd,jbnd->nbij", (query_layer, key_layer))
  466. return relation_score
  467. @auto_docstring
  468. class BrosPreTrainedModel(PreTrainedModel):
  469. config: BrosConfig
  470. base_model_prefix = "bros"
  471. def _init_weights(self, module: nn.Module):
  472. """Initialize the weights"""
  473. std = self.config.initializer_range
  474. if isinstance(module, nn.Linear):
  475. # Slightly different from the TF version which uses truncated_normal for initialization
  476. # cf https://github.com/pytorch/pytorch/pull/5617
  477. module.weight.data.normal_(mean=0.0, std=std)
  478. if module.bias is not None:
  479. module.bias.data.zero_()
  480. elif isinstance(module, nn.Embedding):
  481. module.weight.data.normal_(mean=0.0, std=std)
  482. if module.padding_idx is not None:
  483. module.weight.data[module.padding_idx].zero_()
  484. elif isinstance(module, nn.LayerNorm):
  485. module.bias.data.zero_()
  486. module.weight.data.fill_(1.0)
  487. elif isinstance(module, BrosRelationExtractor):
  488. nn.init.normal_(module.dummy_node, std=std)
  489. @auto_docstring
  490. class BrosModel(BrosPreTrainedModel):
  491. def __init__(self, config, add_pooling_layer=True):
  492. r"""
  493. add_pooling_layer (bool, *optional*, defaults to `True`):
  494. Whether to add a pooling layer
  495. """
  496. super().__init__(config)
  497. self.config = config
  498. self.embeddings = BrosTextEmbeddings(config)
  499. self.bbox_embeddings = BrosBboxEmbeddings(config)
  500. self.encoder = BrosEncoder(config)
  501. self.pooler = BrosPooler(config) if add_pooling_layer else None
  502. self.init_weights()
  503. def get_input_embeddings(self):
  504. return self.embeddings.word_embeddings
  505. def set_input_embeddings(self, value):
  506. self.embeddings.word_embeddings = value
  507. def _prune_heads(self, heads_to_prune):
  508. """
  509. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  510. class PreTrainedModel
  511. """
  512. for layer, heads in heads_to_prune.items():
  513. self.encoder.layer[layer].attention.prune_heads(heads)
  514. @can_return_tuple
  515. @auto_docstring
  516. def forward(
  517. self,
  518. input_ids: Optional[torch.Tensor] = None,
  519. bbox: Optional[torch.Tensor] = None,
  520. attention_mask: Optional[torch.Tensor] = None,
  521. token_type_ids: Optional[torch.Tensor] = None,
  522. position_ids: Optional[torch.Tensor] = None,
  523. head_mask: Optional[torch.Tensor] = None,
  524. inputs_embeds: Optional[torch.Tensor] = None,
  525. encoder_hidden_states: Optional[torch.Tensor] = None,
  526. encoder_attention_mask: Optional[torch.Tensor] = None,
  527. output_attentions: Optional[bool] = None,
  528. output_hidden_states: Optional[bool] = None,
  529. return_dict: Optional[bool] = None,
  530. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  531. r"""
  532. bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
  533. Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
  534. (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
  535. bounding box.
  536. Examples:
  537. ```python
  538. >>> import torch
  539. >>> from transformers import BrosProcessor, BrosModel
  540. >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
  541. >>> model = BrosModel.from_pretrained("jinho8345/bros-base-uncased")
  542. >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
  543. >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
  544. >>> encoding["bbox"] = bbox
  545. >>> outputs = model(**encoding)
  546. >>> last_hidden_states = outputs.last_hidden_state
  547. ```"""
  548. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  549. output_hidden_states = (
  550. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  551. )
  552. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  553. if input_ids is not None and inputs_embeds is not None:
  554. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  555. elif input_ids is not None:
  556. input_shape = input_ids.size()
  557. elif inputs_embeds is not None:
  558. input_shape = inputs_embeds.size()[:-1]
  559. else:
  560. raise ValueError("You have to specify either input_ids or inputs_embeds")
  561. if bbox is None:
  562. raise ValueError("You have to specify bbox")
  563. batch_size, seq_length = input_shape
  564. device = input_ids.device if input_ids is not None else inputs_embeds.device
  565. if attention_mask is None:
  566. attention_mask = torch.ones(input_shape, device=device)
  567. if token_type_ids is None:
  568. if hasattr(self.embeddings, "token_type_ids"):
  569. buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
  570. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
  571. token_type_ids = buffered_token_type_ids_expanded
  572. else:
  573. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  574. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  575. # ourselves in which case we just need to make it broadcastable to all heads.
  576. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
  577. # If a 2D or 3D attention mask is provided for the cross-attention
  578. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  579. if self.config.is_decoder and encoder_hidden_states is not None:
  580. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  581. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  582. if encoder_attention_mask is None:
  583. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  584. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  585. else:
  586. encoder_extended_attention_mask = None
  587. # Prepare head mask if needed
  588. # 1.0 in head_mask indicate we keep the head
  589. # attention_probs has shape bsz x n_heads x N x N
  590. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  591. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  592. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  593. embedding_output = self.embeddings(
  594. input_ids=input_ids,
  595. position_ids=position_ids,
  596. token_type_ids=token_type_ids,
  597. inputs_embeds=inputs_embeds,
  598. )
  599. # if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
  600. if bbox.shape[-1] == 4:
  601. bbox = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]]
  602. scaled_bbox = bbox * self.config.bbox_scale
  603. bbox_position_embeddings = self.bbox_embeddings(scaled_bbox)
  604. encoder_outputs = self.encoder(
  605. embedding_output,
  606. bbox_pos_emb=bbox_position_embeddings,
  607. attention_mask=extended_attention_mask,
  608. head_mask=head_mask,
  609. encoder_hidden_states=encoder_hidden_states,
  610. encoder_attention_mask=encoder_extended_attention_mask,
  611. output_attentions=output_attentions,
  612. output_hidden_states=output_hidden_states,
  613. return_dict=True,
  614. )
  615. sequence_output = encoder_outputs[0]
  616. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  617. return BaseModelOutputWithPoolingAndCrossAttentions(
  618. last_hidden_state=sequence_output,
  619. pooler_output=pooled_output,
  620. hidden_states=encoder_outputs.hidden_states,
  621. attentions=encoder_outputs.attentions,
  622. cross_attentions=encoder_outputs.cross_attentions,
  623. )
  624. @auto_docstring
  625. class BrosForTokenClassification(BrosPreTrainedModel):
  626. _keys_to_ignore_on_load_unexpected = [r"pooler"]
  627. def __init__(self, config):
  628. super().__init__(config)
  629. self.num_labels = config.num_labels
  630. self.bros = BrosModel(config)
  631. classifier_dropout = (
  632. config.classifier_dropout if hasattr(config, "classifier_dropout") else config.hidden_dropout_prob
  633. )
  634. self.dropout = nn.Dropout(classifier_dropout)
  635. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  636. self.init_weights()
  637. @can_return_tuple
  638. @auto_docstring
  639. def forward(
  640. self,
  641. input_ids: Optional[torch.Tensor] = None,
  642. bbox: Optional[torch.Tensor] = None,
  643. attention_mask: Optional[torch.Tensor] = None,
  644. bbox_first_token_mask: Optional[torch.Tensor] = None,
  645. token_type_ids: Optional[torch.Tensor] = None,
  646. position_ids: Optional[torch.Tensor] = None,
  647. head_mask: Optional[torch.Tensor] = None,
  648. inputs_embeds: Optional[torch.Tensor] = None,
  649. labels: Optional[torch.Tensor] = None,
  650. output_attentions: Optional[bool] = None,
  651. output_hidden_states: Optional[bool] = None,
  652. return_dict: Optional[bool] = None,
  653. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  654. r"""
  655. bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
  656. Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
  657. (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
  658. bounding box.
  659. bbox_first_token_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  660. Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`:
  661. - 1 for tokens that are **not masked**,
  662. - 0 for tokens that are **masked**.
  663. Examples:
  664. ```python
  665. >>> import torch
  666. >>> from transformers import BrosProcessor, BrosForTokenClassification
  667. >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
  668. >>> model = BrosForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
  669. >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
  670. >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
  671. >>> encoding["bbox"] = bbox
  672. >>> outputs = model(**encoding)
  673. ```"""
  674. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  675. outputs = self.bros(
  676. input_ids,
  677. bbox=bbox,
  678. attention_mask=attention_mask,
  679. token_type_ids=token_type_ids,
  680. position_ids=position_ids,
  681. head_mask=head_mask,
  682. inputs_embeds=inputs_embeds,
  683. output_attentions=output_attentions,
  684. output_hidden_states=output_hidden_states,
  685. return_dict=True,
  686. )
  687. sequence_output = outputs[0]
  688. sequence_output = self.dropout(sequence_output)
  689. logits = self.classifier(sequence_output)
  690. loss = None
  691. if labels is not None:
  692. loss_fct = CrossEntropyLoss()
  693. if bbox_first_token_mask is not None:
  694. bbox_first_token_mask = bbox_first_token_mask.view(-1)
  695. loss = loss_fct(
  696. logits.view(-1, self.num_labels)[bbox_first_token_mask], labels.view(-1)[bbox_first_token_mask]
  697. )
  698. else:
  699. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  700. return TokenClassifierOutput(
  701. loss=loss,
  702. logits=logits,
  703. hidden_states=outputs.hidden_states,
  704. attentions=outputs.attentions,
  705. )
  706. @auto_docstring(
  707. custom_intro="""
  708. Bros Model with a token classification head on top (initial_token_layers and subsequent_token_layer on top of the
  709. hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. The initial_token_classifier is used to
  710. predict the first token of each entity, and the subsequent_token_classifier is used to predict the subsequent
  711. tokens within an entity. Compared to BrosForTokenClassification, this model is more robust to serialization errors
  712. since it predicts next token from one token.
  713. """
  714. )
  715. class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
  716. _keys_to_ignore_on_load_unexpected = [r"pooler"]
  717. def __init__(self, config):
  718. super().__init__(config)
  719. self.config = config
  720. self.num_labels = config.num_labels
  721. self.n_relations = config.n_relations
  722. self.backbone_hidden_size = config.hidden_size
  723. self.bros = BrosModel(config)
  724. classifier_dropout = (
  725. config.classifier_dropout if hasattr(config, "classifier_dropout") else config.hidden_dropout_prob
  726. )
  727. # Initial token classification for Entity Extraction (NER)
  728. self.initial_token_classifier = nn.Sequential(
  729. nn.Dropout(classifier_dropout),
  730. nn.Linear(config.hidden_size, config.hidden_size),
  731. nn.Dropout(classifier_dropout),
  732. nn.Linear(config.hidden_size, config.num_labels),
  733. )
  734. # Subsequent token classification for Entity Extraction (NER)
  735. self.subsequent_token_classifier = BrosRelationExtractor(config)
  736. self.init_weights()
  737. @can_return_tuple
  738. @auto_docstring
  739. def forward(
  740. self,
  741. input_ids: Optional[torch.Tensor] = None,
  742. bbox: Optional[torch.Tensor] = None,
  743. attention_mask: Optional[torch.Tensor] = None,
  744. bbox_first_token_mask: Optional[torch.Tensor] = None,
  745. token_type_ids: Optional[torch.Tensor] = None,
  746. position_ids: Optional[torch.Tensor] = None,
  747. head_mask: Optional[torch.Tensor] = None,
  748. inputs_embeds: Optional[torch.Tensor] = None,
  749. initial_token_labels: Optional[torch.Tensor] = None,
  750. subsequent_token_labels: Optional[torch.Tensor] = None,
  751. output_attentions: Optional[bool] = None,
  752. output_hidden_states: Optional[bool] = None,
  753. return_dict: Optional[bool] = None,
  754. ) -> Union[tuple[torch.Tensor], BrosSpadeOutput]:
  755. r"""
  756. bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
  757. Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
  758. (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
  759. bounding box.
  760. bbox_first_token_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  761. Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`:
  762. - 1 for tokens that are **not masked**,
  763. - 0 for tokens that are **masked**.
  764. initial_token_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  765. Labels for the initial token classification.
  766. subsequent_token_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  767. Labels for the subsequent token classification.
  768. Examples:
  769. ```python
  770. >>> import torch
  771. >>> from transformers import BrosProcessor, BrosSpadeEEForTokenClassification
  772. >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
  773. >>> model = BrosSpadeEEForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
  774. >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
  775. >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
  776. >>> encoding["bbox"] = bbox
  777. >>> outputs = model(**encoding)
  778. ```"""
  779. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  780. outputs = self.bros(
  781. input_ids=input_ids,
  782. bbox=bbox,
  783. attention_mask=attention_mask,
  784. token_type_ids=token_type_ids,
  785. position_ids=position_ids,
  786. head_mask=head_mask,
  787. inputs_embeds=inputs_embeds,
  788. output_attentions=output_attentions,
  789. output_hidden_states=output_hidden_states,
  790. return_dict=True,
  791. )
  792. last_hidden_states = outputs[0]
  793. last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()
  794. initial_token_logits = self.initial_token_classifier(last_hidden_states).transpose(0, 1).contiguous()
  795. subsequent_token_logits = self.subsequent_token_classifier(last_hidden_states, last_hidden_states).squeeze(0)
  796. # make subsequent token (sequence token classification) mask
  797. inv_attention_mask = 1 - attention_mask
  798. batch_size, max_seq_length = inv_attention_mask.shape
  799. device = inv_attention_mask.device
  800. invalid_token_mask = torch.cat([inv_attention_mask, torch.zeros([batch_size, 1]).to(device)], axis=1).bool()
  801. subsequent_token_logits = subsequent_token_logits.masked_fill(
  802. invalid_token_mask[:, None, :], torch.finfo(subsequent_token_logits.dtype).min
  803. )
  804. self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool)
  805. subsequent_token_logits = subsequent_token_logits.masked_fill(
  806. self_token_mask[None, :, :], torch.finfo(subsequent_token_logits.dtype).min
  807. )
  808. subsequent_token_mask = attention_mask.view(-1).bool()
  809. loss = None
  810. if initial_token_labels is not None and subsequent_token_labels is not None:
  811. loss_fct = CrossEntropyLoss()
  812. # get initial token loss
  813. initial_token_labels = initial_token_labels.view(-1)
  814. if bbox_first_token_mask is not None:
  815. bbox_first_token_mask = bbox_first_token_mask.view(-1)
  816. initial_token_loss = loss_fct(
  817. initial_token_logits.view(-1, self.num_labels)[bbox_first_token_mask],
  818. initial_token_labels[bbox_first_token_mask],
  819. )
  820. else:
  821. initial_token_loss = loss_fct(initial_token_logits.view(-1, self.num_labels), initial_token_labels)
  822. subsequent_token_labels = subsequent_token_labels.view(-1)
  823. subsequent_token_loss = loss_fct(
  824. subsequent_token_logits.view(-1, max_seq_length + 1)[subsequent_token_mask],
  825. subsequent_token_labels[subsequent_token_mask],
  826. )
  827. loss = initial_token_loss + subsequent_token_loss
  828. return BrosSpadeOutput(
  829. loss=loss,
  830. initial_token_logits=initial_token_logits,
  831. subsequent_token_logits=subsequent_token_logits,
  832. hidden_states=outputs.hidden_states,
  833. attentions=outputs.attentions,
  834. )
  835. @auto_docstring(
  836. custom_intro="""
  837. Bros Model with a token classification head on top (a entity_linker layer on top of the hidden-states output) e.g.
  838. for Entity-Linking. The entity_linker is used to predict intra-entity links (one entity to another entity).
  839. """
  840. )
  841. class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
  842. _keys_to_ignore_on_load_unexpected = [r"pooler"]
  843. def __init__(self, config):
  844. super().__init__(config)
  845. self.config = config
  846. self.num_labels = config.num_labels
  847. self.n_relations = config.n_relations
  848. self.backbone_hidden_size = config.hidden_size
  849. self.bros = BrosModel(config)
  850. (config.classifier_dropout if hasattr(config, "classifier_dropout") else config.hidden_dropout_prob)
  851. self.entity_linker = BrosRelationExtractor(config)
  852. self.init_weights()
  853. @can_return_tuple
  854. @auto_docstring
  855. def forward(
  856. self,
  857. input_ids: Optional[torch.Tensor] = None,
  858. bbox: Optional[torch.Tensor] = None,
  859. attention_mask: Optional[torch.Tensor] = None,
  860. bbox_first_token_mask: Optional[torch.Tensor] = None,
  861. token_type_ids: Optional[torch.Tensor] = None,
  862. position_ids: Optional[torch.Tensor] = None,
  863. head_mask: Optional[torch.Tensor] = None,
  864. inputs_embeds: Optional[torch.Tensor] = None,
  865. labels: Optional[torch.Tensor] = None,
  866. output_attentions: Optional[bool] = None,
  867. output_hidden_states: Optional[bool] = None,
  868. return_dict: Optional[bool] = None,
  869. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  870. r"""
  871. bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
  872. Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
  873. (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
  874. bounding box.
  875. bbox_first_token_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  876. Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`:
  877. - 1 for tokens that are **not masked**,
  878. - 0 for tokens that are **masked**.
  879. Examples:
  880. ```python
  881. >>> import torch
  882. >>> from transformers import BrosProcessor, BrosSpadeELForTokenClassification
  883. >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
  884. >>> model = BrosSpadeELForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
  885. >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
  886. >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
  887. >>> encoding["bbox"] = bbox
  888. >>> outputs = model(**encoding)
  889. ```"""
  890. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  891. outputs = self.bros(
  892. input_ids=input_ids,
  893. bbox=bbox,
  894. attention_mask=attention_mask,
  895. token_type_ids=token_type_ids,
  896. position_ids=position_ids,
  897. head_mask=head_mask,
  898. inputs_embeds=inputs_embeds,
  899. output_attentions=output_attentions,
  900. output_hidden_states=output_hidden_states,
  901. return_dict=True,
  902. )
  903. last_hidden_states = outputs[0]
  904. last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()
  905. logits = self.entity_linker(last_hidden_states, last_hidden_states).squeeze(0)
  906. loss = None
  907. if labels is not None:
  908. loss_fct = CrossEntropyLoss()
  909. batch_size, max_seq_length = attention_mask.shape
  910. device = attention_mask.device
  911. self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool)
  912. mask = bbox_first_token_mask.view(-1)
  913. bbox_first_token_mask = torch.cat(
  914. [
  915. ~bbox_first_token_mask,
  916. torch.zeros([batch_size, 1], dtype=torch.bool, device=device),
  917. ],
  918. axis=1,
  919. )
  920. logits = logits.masked_fill(bbox_first_token_mask[:, None, :], torch.finfo(logits.dtype).min)
  921. logits = logits.masked_fill(self_token_mask[None, :, :], torch.finfo(logits.dtype).min)
  922. loss = loss_fct(logits.view(-1, max_seq_length + 1)[mask], labels.view(-1)[mask])
  923. return TokenClassifierOutput(
  924. loss=loss,
  925. logits=logits,
  926. hidden_states=outputs.hidden_states,
  927. attentions=outputs.attentions,
  928. )
# Names exported by `from <this module> import *`.
__all__ = [
    "BrosPreTrainedModel",
    "BrosModel",
    "BrosForTokenClassification",
    "BrosSpadeEEForTokenClassification",
    "BrosSpadeELForTokenClassification",
]