modeling_luke.py 97 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179
  1. # coding=utf-8
  2. # Copyright Studio Ousia and The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch LUKE model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Optional, Union
  19. import torch
  20. from torch import nn
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  22. from ...activations import ACT2FN, gelu
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  25. from ...modeling_utils import PreTrainedModel
  26. from ...pytorch_utils import apply_chunking_to_forward
  27. from ...utils import ModelOutput, auto_docstring, logging
  28. from .configuration_luke import LukeConfig
  29. logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
Base class for outputs of the LUKE model.
"""
)
class BaseLukeModelOutputWithPooling(BaseModelOutputWithPooling):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Last layer hidden-state of the first token of the sequence (classification token) further processed by a
        Linear layer and a Tanh activation function.
    entity_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, entity_length, hidden_size)`):
        Sequence of entity hidden-states at the output of the last layer of the model.
    entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
        layer plus the initial entity embedding outputs.
    """

    # Fields declared here are added on top of whatever `BaseModelOutputWithPooling` provides;
    # both default to None so they can be omitted.
    entity_last_hidden_state: Optional[torch.FloatTensor] = None
    entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
Base class for model's outputs, with potential hidden states and attentions.
"""
)
class BaseLukeModelOutput(BaseModelOutput):
    r"""
    entity_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, entity_length, hidden_size)`):
        Sequence of entity hidden-states at the output of the last layer of the model.
    entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
        layer plus the initial entity embedding outputs.
    """

    # Fields declared here are added on top of whatever `BaseModelOutput` provides;
    # both default to None so they can be omitted.
    entity_last_hidden_state: Optional[torch.FloatTensor] = None
    entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
# NOTE(review): the `custom_intro` below is the same text as the one on `BaseLukeModelOutput`;
# it does not mention masked language modeling. Left unchanged because it is a runtime string.
@dataclass
@auto_docstring(
    custom_intro="""
Base class for model's outputs, with potential hidden states and attentions.
"""
)
class LukeMaskedLMOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        The sum of masked language modeling (MLM) loss and entity prediction loss.
    mlm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Masked language modeling (MLM) loss.
    mep_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Masked entity prediction (MEP) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    entity_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the entity prediction head (scores for each entity vocabulary token before SoftMax).
    entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
        layer plus the initial entity embedding outputs.
    """

    # Loss terms; per the docstring, `loss` is the sum of the MLM and entity prediction losses.
    loss: Optional[torch.FloatTensor] = None
    mlm_loss: Optional[torch.FloatTensor] = None
    mep_loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    # NOTE(review): the docstring gives `entity_logits` the same shape as `logits`. Since it scores
    # entity vocabulary tokens, the shape is presumably
    # `(batch_size, entity_length, config.entity_vocab_size)` — confirm against the entity
    # prediction head before changing the docstring.
    entity_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
Outputs of entity classification models.
"""
)
class EntityClassificationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification scores (before SoftMax).
    entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
        layer plus the initial entity embedding outputs.
    """

    # Every field defaults to None. `hidden_states` and `attentions` have no entry in the
    # docstring above.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
Outputs of entity pair classification models.
"""
)
class EntityPairClassificationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification scores (before SoftMax).
    entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
        layer plus the initial entity embedding outputs.
    """

    # Every field defaults to None. `hidden_states` and `attentions` have no entry in the
    # docstring above.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
Outputs of entity span classification models.
"""
)
class EntitySpanClassificationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification loss.
    logits (`torch.FloatTensor` of shape `(batch_size, entity_length, config.num_labels)`):
        Classification scores (before SoftMax).
    entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
        layer plus the initial entity embedding outputs.
    """

    # Every field defaults to None. `hidden_states` and `attentions` have no entry in the
    # docstring above.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
Outputs of sentence classification models.
"""
)
class LukeSequenceClassifierOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
        layer plus the initial entity embedding outputs.
    """

    # Every field defaults to None. `hidden_states` and `attentions` have no entry in the
    # docstring above.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
Base class for outputs of token classification models.
"""
)
class LukeTokenClassifierOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification loss.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
        Classification scores (before SoftMax).
    entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
        layer plus the initial entity embedding outputs.
    """

    # Every field defaults to None. `hidden_states` and `attentions` have no entry in the
    # docstring above.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
Outputs of question answering models.
"""
)
class LukeQuestionAnsweringModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
    entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
        layer plus the initial entity embedding outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    # NOTE(review): `start_logits` and `end_logits` have no entry in the docstring above;
    # presumably `auto_docstring` supplies the standard descriptions — confirm.
    start_logits: Optional[torch.FloatTensor] = None
    end_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
Outputs of multiple choice models.
"""
)
class LukeMultipleChoiceModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification loss.
    logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
        *num_choices* is the second dimension of the input tensors. (see *input_ids* above).

        Classification scores (before SoftMax).
    entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
        layer plus the initial entity embedding outputs.
    """

    # Every field defaults to None. `hidden_states` and `attentions` have no entry in the
    # docstring above.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  252. class LukeEmbeddings(nn.Module):
  253. """
  254. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  255. """
  256. def __init__(self, config):
  257. super().__init__()
  258. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  259. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  260. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  261. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  262. # any TensorFlow checkpoint file
  263. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  264. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  265. # End copy
  266. self.padding_idx = config.pad_token_id
  267. self.position_embeddings = nn.Embedding(
  268. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  269. )
  270. def forward(
  271. self,
  272. input_ids=None,
  273. token_type_ids=None,
  274. position_ids=None,
  275. inputs_embeds=None,
  276. ):
  277. if position_ids is None:
  278. if input_ids is not None:
  279. # Create the position ids from the input token ids. Any padded tokens remain padded.
  280. position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device)
  281. else:
  282. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  283. if input_ids is not None:
  284. input_shape = input_ids.size()
  285. else:
  286. input_shape = inputs_embeds.size()[:-1]
  287. if token_type_ids is None:
  288. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  289. if inputs_embeds is None:
  290. inputs_embeds = self.word_embeddings(input_ids)
  291. position_embeddings = self.position_embeddings(position_ids)
  292. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  293. embeddings = inputs_embeds + position_embeddings + token_type_embeddings
  294. embeddings = self.LayerNorm(embeddings)
  295. embeddings = self.dropout(embeddings)
  296. return embeddings
  297. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  298. """
  299. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  300. Args:
  301. inputs_embeds: torch.Tensor
  302. Returns: torch.Tensor
  303. """
  304. input_shape = inputs_embeds.size()[:-1]
  305. sequence_length = input_shape[1]
  306. position_ids = torch.arange(
  307. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  308. )
  309. return position_ids.unsqueeze(0).expand(input_shape)
  310. class LukeEntityEmbeddings(nn.Module):
  311. def __init__(self, config: LukeConfig):
  312. super().__init__()
  313. self.config = config
  314. self.entity_embeddings = nn.Embedding(config.entity_vocab_size, config.entity_emb_size, padding_idx=0)
  315. if config.entity_emb_size != config.hidden_size:
  316. self.entity_embedding_dense = nn.Linear(config.entity_emb_size, config.hidden_size, bias=False)
  317. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  318. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  319. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  320. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  321. def forward(
  322. self,
  323. entity_ids: torch.LongTensor,
  324. position_ids: torch.LongTensor,
  325. token_type_ids: Optional[torch.LongTensor] = None,
  326. ):
  327. if token_type_ids is None:
  328. token_type_ids = torch.zeros_like(entity_ids)
  329. entity_embeddings = self.entity_embeddings(entity_ids)
  330. if self.config.entity_emb_size != self.config.hidden_size:
  331. entity_embeddings = self.entity_embedding_dense(entity_embeddings)
  332. position_embeddings = self.position_embeddings(position_ids.clamp(min=0))
  333. position_embedding_mask = (position_ids != -1).type_as(position_embeddings).unsqueeze(-1)
  334. position_embeddings = position_embeddings * position_embedding_mask
  335. position_embeddings = torch.sum(position_embeddings, dim=-2)
  336. position_embeddings = position_embeddings / position_embedding_mask.sum(dim=-2).clamp(min=1e-7)
  337. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  338. embeddings = entity_embeddings + position_embeddings + token_type_embeddings
  339. embeddings = self.LayerNorm(embeddings)
  340. embeddings = self.dropout(embeddings)
  341. return embeddings
  342. class LukeSelfAttention(nn.Module):
  343. def __init__(self, config):
  344. super().__init__()
  345. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  346. raise ValueError(
  347. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  348. f"heads {config.num_attention_heads}."
  349. )
  350. self.num_attention_heads = config.num_attention_heads
  351. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  352. self.all_head_size = self.num_attention_heads * self.attention_head_size
  353. self.use_entity_aware_attention = config.use_entity_aware_attention
  354. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  355. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  356. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  357. if self.use_entity_aware_attention:
  358. self.w2e_query = nn.Linear(config.hidden_size, self.all_head_size)
  359. self.e2w_query = nn.Linear(config.hidden_size, self.all_head_size)
  360. self.e2e_query = nn.Linear(config.hidden_size, self.all_head_size)
  361. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  362. def transpose_for_scores(self, x):
  363. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  364. x = x.view(*new_x_shape)
  365. return x.permute(0, 2, 1, 3)
  366. def forward(
  367. self,
  368. word_hidden_states,
  369. entity_hidden_states,
  370. attention_mask=None,
  371. head_mask=None,
  372. output_attentions=False,
  373. ):
  374. word_size = word_hidden_states.size(1)
  375. if entity_hidden_states is None:
  376. concat_hidden_states = word_hidden_states
  377. else:
  378. concat_hidden_states = torch.cat([word_hidden_states, entity_hidden_states], dim=1)
  379. key_layer = self.transpose_for_scores(self.key(concat_hidden_states))
  380. value_layer = self.transpose_for_scores(self.value(concat_hidden_states))
  381. if self.use_entity_aware_attention and entity_hidden_states is not None:
  382. # compute query vectors using word-word (w2w), word-entity (w2e), entity-word (e2w), entity-entity (e2e)
  383. # query layers
  384. w2w_query_layer = self.transpose_for_scores(self.query(word_hidden_states))
  385. w2e_query_layer = self.transpose_for_scores(self.w2e_query(word_hidden_states))
  386. e2w_query_layer = self.transpose_for_scores(self.e2w_query(entity_hidden_states))
  387. e2e_query_layer = self.transpose_for_scores(self.e2e_query(entity_hidden_states))
  388. # compute w2w, w2e, e2w, and e2e key vectors used with the query vectors computed above
  389. w2w_key_layer = key_layer[:, :, :word_size, :]
  390. e2w_key_layer = key_layer[:, :, :word_size, :]
  391. w2e_key_layer = key_layer[:, :, word_size:, :]
  392. e2e_key_layer = key_layer[:, :, word_size:, :]
  393. # compute attention scores based on the dot product between the query and key vectors
  394. w2w_attention_scores = torch.matmul(w2w_query_layer, w2w_key_layer.transpose(-1, -2))
  395. w2e_attention_scores = torch.matmul(w2e_query_layer, w2e_key_layer.transpose(-1, -2))
  396. e2w_attention_scores = torch.matmul(e2w_query_layer, e2w_key_layer.transpose(-1, -2))
  397. e2e_attention_scores = torch.matmul(e2e_query_layer, e2e_key_layer.transpose(-1, -2))
  398. # combine attention scores to create the final attention score matrix
  399. word_attention_scores = torch.cat([w2w_attention_scores, w2e_attention_scores], dim=3)
  400. entity_attention_scores = torch.cat([e2w_attention_scores, e2e_attention_scores], dim=3)
  401. attention_scores = torch.cat([word_attention_scores, entity_attention_scores], dim=2)
  402. else:
  403. query_layer = self.transpose_for_scores(self.query(concat_hidden_states))
  404. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  405. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  406. if attention_mask is not None:
  407. # Apply the attention mask is (precomputed for all layers in LukeModel forward() function)
  408. attention_scores = attention_scores + attention_mask
  409. # Normalize the attention scores to probabilities.
  410. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  411. # This is actually dropping out entire tokens to attend to, which might
  412. # seem a bit unusual, but is taken from the original Transformer paper.
  413. attention_probs = self.dropout(attention_probs)
  414. # Mask heads if we want to
  415. if head_mask is not None:
  416. attention_probs = attention_probs * head_mask
  417. context_layer = torch.matmul(attention_probs, value_layer)
  418. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  419. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  420. context_layer = context_layer.view(*new_context_layer_shape)
  421. output_word_hidden_states = context_layer[:, :word_size, :]
  422. if entity_hidden_states is None:
  423. output_entity_hidden_states = None
  424. else:
  425. output_entity_hidden_states = context_layer[:, word_size:, :]
  426. if output_attentions:
  427. outputs = (output_word_hidden_states, output_entity_hidden_states, attention_probs)
  428. else:
  429. outputs = (output_word_hidden_states, output_entity_hidden_states)
  430. return outputs
  431. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
  432. class LukeSelfOutput(nn.Module):
  433. def __init__(self, config):
  434. super().__init__()
  435. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  436. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  437. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  438. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  439. hidden_states = self.dense(hidden_states)
  440. hidden_states = self.dropout(hidden_states)
  441. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  442. return hidden_states
  443. class LukeAttention(nn.Module):
  444. def __init__(self, config):
  445. super().__init__()
  446. self.self = LukeSelfAttention(config)
  447. self.output = LukeSelfOutput(config)
  448. self.pruned_heads = set()
  449. def prune_heads(self, heads):
  450. raise NotImplementedError("LUKE does not support the pruning of attention heads")
  451. def forward(
  452. self,
  453. word_hidden_states,
  454. entity_hidden_states,
  455. attention_mask=None,
  456. head_mask=None,
  457. output_attentions=False,
  458. ):
  459. word_size = word_hidden_states.size(1)
  460. self_outputs = self.self(
  461. word_hidden_states,
  462. entity_hidden_states,
  463. attention_mask,
  464. head_mask,
  465. output_attentions,
  466. )
  467. if entity_hidden_states is None:
  468. concat_self_outputs = self_outputs[0]
  469. concat_hidden_states = word_hidden_states
  470. else:
  471. concat_self_outputs = torch.cat(self_outputs[:2], dim=1)
  472. concat_hidden_states = torch.cat([word_hidden_states, entity_hidden_states], dim=1)
  473. attention_output = self.output(concat_self_outputs, concat_hidden_states)
  474. word_attention_output = attention_output[:, :word_size, :]
  475. if entity_hidden_states is None:
  476. entity_attention_output = None
  477. else:
  478. entity_attention_output = attention_output[:, word_size:, :]
  479. # add attentions if we output them
  480. outputs = (word_attention_output, entity_attention_output) + self_outputs[2:]
  481. return outputs
  482. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  483. class LukeIntermediate(nn.Module):
  484. def __init__(self, config):
  485. super().__init__()
  486. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  487. if isinstance(config.hidden_act, str):
  488. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  489. else:
  490. self.intermediate_act_fn = config.hidden_act
  491. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  492. hidden_states = self.dense(hidden_states)
  493. hidden_states = self.intermediate_act_fn(hidden_states)
  494. return hidden_states
  495. # Copied from transformers.models.bert.modeling_bert.BertOutput
  496. class LukeOutput(nn.Module):
  497. def __init__(self, config):
  498. super().__init__()
  499. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  500. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  501. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  502. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  503. hidden_states = self.dense(hidden_states)
  504. hidden_states = self.dropout(hidden_states)
  505. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  506. return hidden_states
  507. class LukeLayer(GradientCheckpointingLayer):
  508. def __init__(self, config):
  509. super().__init__()
  510. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  511. self.seq_len_dim = 1
  512. self.attention = LukeAttention(config)
  513. self.intermediate = LukeIntermediate(config)
  514. self.output = LukeOutput(config)
  515. def forward(
  516. self,
  517. word_hidden_states,
  518. entity_hidden_states,
  519. attention_mask=None,
  520. head_mask=None,
  521. output_attentions=False,
  522. ):
  523. word_size = word_hidden_states.size(1)
  524. self_attention_outputs = self.attention(
  525. word_hidden_states,
  526. entity_hidden_states,
  527. attention_mask,
  528. head_mask,
  529. output_attentions=output_attentions,
  530. )
  531. if entity_hidden_states is None:
  532. concat_attention_output = self_attention_outputs[0]
  533. else:
  534. concat_attention_output = torch.cat(self_attention_outputs[:2], dim=1)
  535. outputs = self_attention_outputs[2:] # add self attentions if we output attention weights
  536. layer_output = apply_chunking_to_forward(
  537. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, concat_attention_output
  538. )
  539. word_layer_output = layer_output[:, :word_size, :]
  540. if entity_hidden_states is None:
  541. entity_layer_output = None
  542. else:
  543. entity_layer_output = layer_output[:, word_size:, :]
  544. outputs = (word_layer_output, entity_layer_output) + outputs
  545. return outputs
  546. def feed_forward_chunk(self, attention_output):
  547. intermediate_output = self.intermediate(attention_output)
  548. layer_output = self.output(intermediate_output, attention_output)
  549. return layer_output
  550. class LukeEncoder(nn.Module):
  551. def __init__(self, config):
  552. super().__init__()
  553. self.config = config
  554. self.layer = nn.ModuleList([LukeLayer(config) for _ in range(config.num_hidden_layers)])
  555. self.gradient_checkpointing = False
  556. def forward(
  557. self,
  558. word_hidden_states,
  559. entity_hidden_states,
  560. attention_mask=None,
  561. head_mask=None,
  562. output_attentions=False,
  563. output_hidden_states=False,
  564. return_dict=True,
  565. ):
  566. all_word_hidden_states = () if output_hidden_states else None
  567. all_entity_hidden_states = () if output_hidden_states else None
  568. all_self_attentions = () if output_attentions else None
  569. for i, layer_module in enumerate(self.layer):
  570. if output_hidden_states:
  571. all_word_hidden_states = all_word_hidden_states + (word_hidden_states,)
  572. all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,)
  573. layer_head_mask = head_mask[i] if head_mask is not None else None
  574. layer_outputs = layer_module(
  575. word_hidden_states,
  576. entity_hidden_states,
  577. attention_mask,
  578. layer_head_mask,
  579. output_attentions,
  580. )
  581. word_hidden_states = layer_outputs[0]
  582. if entity_hidden_states is not None:
  583. entity_hidden_states = layer_outputs[1]
  584. if output_attentions:
  585. all_self_attentions = all_self_attentions + (layer_outputs[2],)
  586. if output_hidden_states:
  587. all_word_hidden_states = all_word_hidden_states + (word_hidden_states,)
  588. all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,)
  589. if not return_dict:
  590. return tuple(
  591. v
  592. for v in [
  593. word_hidden_states,
  594. all_word_hidden_states,
  595. all_self_attentions,
  596. entity_hidden_states,
  597. all_entity_hidden_states,
  598. ]
  599. if v is not None
  600. )
  601. return BaseLukeModelOutput(
  602. last_hidden_state=word_hidden_states,
  603. hidden_states=all_word_hidden_states,
  604. attentions=all_self_attentions,
  605. entity_last_hidden_state=entity_hidden_states,
  606. entity_hidden_states=all_entity_hidden_states,
  607. )
  608. # Copied from transformers.models.bert.modeling_bert.BertPooler
  609. class LukePooler(nn.Module):
  610. def __init__(self, config):
  611. super().__init__()
  612. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  613. self.activation = nn.Tanh()
  614. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  615. # We "pool" the model by simply taking the hidden state corresponding
  616. # to the first token.
  617. first_token_tensor = hidden_states[:, 0]
  618. pooled_output = self.dense(first_token_tensor)
  619. pooled_output = self.activation(pooled_output)
  620. return pooled_output
  621. class EntityPredictionHeadTransform(nn.Module):
  622. def __init__(self, config):
  623. super().__init__()
  624. self.dense = nn.Linear(config.hidden_size, config.entity_emb_size)
  625. if isinstance(config.hidden_act, str):
  626. self.transform_act_fn = ACT2FN[config.hidden_act]
  627. else:
  628. self.transform_act_fn = config.hidden_act
  629. self.LayerNorm = nn.LayerNorm(config.entity_emb_size, eps=config.layer_norm_eps)
  630. def forward(self, hidden_states):
  631. hidden_states = self.dense(hidden_states)
  632. hidden_states = self.transform_act_fn(hidden_states)
  633. hidden_states = self.LayerNorm(hidden_states)
  634. return hidden_states
  635. class EntityPredictionHead(nn.Module):
  636. def __init__(self, config):
  637. super().__init__()
  638. self.config = config
  639. self.transform = EntityPredictionHeadTransform(config)
  640. self.decoder = nn.Linear(config.entity_emb_size, config.entity_vocab_size, bias=False)
  641. self.bias = nn.Parameter(torch.zeros(config.entity_vocab_size))
  642. def forward(self, hidden_states):
  643. hidden_states = self.transform(hidden_states)
  644. hidden_states = self.decoder(hidden_states) + self.bias
  645. return hidden_states
  646. @auto_docstring
  647. class LukePreTrainedModel(PreTrainedModel):
  648. config: LukeConfig
  649. base_model_prefix = "luke"
  650. supports_gradient_checkpointing = True
  651. _no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"]
  652. def _init_weights(self, module: nn.Module):
  653. """Initialize the weights"""
  654. if isinstance(module, nn.Linear):
  655. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  656. if module.bias is not None:
  657. module.bias.data.zero_()
  658. elif isinstance(module, nn.Embedding):
  659. if module.embedding_dim == 1: # embedding for bias parameters
  660. module.weight.data.zero_()
  661. else:
  662. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  663. if module.padding_idx is not None:
  664. module.weight.data[module.padding_idx].zero_()
  665. elif isinstance(module, nn.LayerNorm):
  666. module.bias.data.zero_()
  667. module.weight.data.fill_(1.0)
  668. @auto_docstring(
  669. custom_intro="""
  670. The bare LUKE model transformer outputting raw hidden-states for both word tokens and entities without any
  671. """
  672. )
  673. class LukeModel(LukePreTrainedModel):
  674. def __init__(self, config: LukeConfig, add_pooling_layer: bool = True):
  675. r"""
  676. add_pooling_layer (bool, *optional*, defaults to `True`):
  677. Whether to add a pooling layer
  678. """
  679. super().__init__(config)
  680. self.config = config
  681. self.embeddings = LukeEmbeddings(config)
  682. self.entity_embeddings = LukeEntityEmbeddings(config)
  683. self.encoder = LukeEncoder(config)
  684. self.pooler = LukePooler(config) if add_pooling_layer else None
  685. # Initialize weights and apply final processing
  686. self.post_init()
  687. def get_input_embeddings(self):
  688. return self.embeddings.word_embeddings
  689. def set_input_embeddings(self, value):
  690. self.embeddings.word_embeddings = value
  691. def get_entity_embeddings(self):
  692. return self.entity_embeddings.entity_embeddings
  693. def set_entity_embeddings(self, value):
  694. self.entity_embeddings.entity_embeddings = value
  695. def _prune_heads(self, heads_to_prune):
  696. raise NotImplementedError("LUKE does not support the pruning of attention heads")
  697. @auto_docstring
  698. def forward(
  699. self,
  700. input_ids: Optional[torch.LongTensor] = None,
  701. attention_mask: Optional[torch.FloatTensor] = None,
  702. token_type_ids: Optional[torch.LongTensor] = None,
  703. position_ids: Optional[torch.LongTensor] = None,
  704. entity_ids: Optional[torch.LongTensor] = None,
  705. entity_attention_mask: Optional[torch.FloatTensor] = None,
  706. entity_token_type_ids: Optional[torch.LongTensor] = None,
  707. entity_position_ids: Optional[torch.LongTensor] = None,
  708. head_mask: Optional[torch.FloatTensor] = None,
  709. inputs_embeds: Optional[torch.FloatTensor] = None,
  710. output_attentions: Optional[bool] = None,
  711. output_hidden_states: Optional[bool] = None,
  712. return_dict: Optional[bool] = None,
  713. ) -> Union[tuple, BaseLukeModelOutputWithPooling]:
  714. r"""
  715. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  716. Indices of entity tokens in the entity vocabulary.
  717. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  718. [`PreTrainedTokenizer.__call__`] for details.
  719. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  720. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  721. - 1 for entity tokens that are **not masked**,
  722. - 0 for entity tokens that are **masked**.
  723. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  724. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  725. selected in `[0, 1]`:
  726. - 0 corresponds to a *portion A* entity token,
  727. - 1 corresponds to a *portion B* entity token.
  728. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  729. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  730. config.max_position_embeddings - 1]`.
  731. Examples:
  732. ```python
  733. >>> from transformers import AutoTokenizer, LukeModel
  734. >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-base")
  735. >>> model = LukeModel.from_pretrained("studio-ousia/luke-base")
  736. # Compute the contextualized entity representation corresponding to the entity mention "Beyoncé"
  737. >>> text = "Beyoncé lives in Los Angeles."
  738. >>> entity_spans = [(0, 7)] # character-based entity span corresponding to "Beyoncé"
  739. >>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
  740. >>> outputs = model(**encoding)
  741. >>> word_last_hidden_state = outputs.last_hidden_state
  742. >>> entity_last_hidden_state = outputs.entity_last_hidden_state
  743. # Input Wikipedia entities to obtain enriched contextualized representations of word tokens
  744. >>> text = "Beyoncé lives in Los Angeles."
  745. >>> entities = [
  746. ... "Beyoncé",
  747. ... "Los Angeles",
  748. ... ] # Wikipedia entity titles corresponding to the entity mentions "Beyoncé" and "Los Angeles"
  749. >>> entity_spans = [
  750. ... (0, 7),
  751. ... (17, 28),
  752. ... ] # character-based entity spans corresponding to "Beyoncé" and "Los Angeles"
  753. >>> encoding = tokenizer(
  754. ... text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt"
  755. ... )
  756. >>> outputs = model(**encoding)
  757. >>> word_last_hidden_state = outputs.last_hidden_state
  758. >>> entity_last_hidden_state = outputs.entity_last_hidden_state
  759. ```"""
  760. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  761. output_hidden_states = (
  762. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  763. )
  764. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  765. if input_ids is not None and inputs_embeds is not None:
  766. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  767. elif input_ids is not None:
  768. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  769. input_shape = input_ids.size()
  770. elif inputs_embeds is not None:
  771. input_shape = inputs_embeds.size()[:-1]
  772. else:
  773. raise ValueError("You have to specify either input_ids or inputs_embeds")
  774. batch_size, seq_length = input_shape
  775. device = input_ids.device if input_ids is not None else inputs_embeds.device
  776. if attention_mask is None:
  777. attention_mask = torch.ones((batch_size, seq_length), device=device)
  778. if token_type_ids is None:
  779. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  780. if entity_ids is not None:
  781. entity_seq_length = entity_ids.size(1)
  782. if entity_attention_mask is None:
  783. entity_attention_mask = torch.ones((batch_size, entity_seq_length), device=device)
  784. if entity_token_type_ids is None:
  785. entity_token_type_ids = torch.zeros((batch_size, entity_seq_length), dtype=torch.long, device=device)
  786. # Prepare head mask if needed
  787. # 1.0 in head_mask indicate we keep the head
  788. # attention_probs has shape bsz x n_heads x N x N
  789. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  790. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  791. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  792. # First, compute word embeddings
  793. word_embedding_output = self.embeddings(
  794. input_ids=input_ids,
  795. position_ids=position_ids,
  796. token_type_ids=token_type_ids,
  797. inputs_embeds=inputs_embeds,
  798. )
  799. # Second, compute extended attention mask
  800. extended_attention_mask = self.get_extended_attention_mask(attention_mask, entity_attention_mask)
  801. # Third, compute entity embeddings and concatenate with word embeddings
  802. if entity_ids is None:
  803. entity_embedding_output = None
  804. else:
  805. entity_embedding_output = self.entity_embeddings(entity_ids, entity_position_ids, entity_token_type_ids)
  806. # Fourth, send embeddings through the model
  807. encoder_outputs = self.encoder(
  808. word_embedding_output,
  809. entity_embedding_output,
  810. attention_mask=extended_attention_mask,
  811. head_mask=head_mask,
  812. output_attentions=output_attentions,
  813. output_hidden_states=output_hidden_states,
  814. return_dict=return_dict,
  815. )
  816. # Fifth, get the output. LukeModel outputs the same as BertModel, namely sequence_output of shape (batch_size, seq_len, hidden_size)
  817. sequence_output = encoder_outputs[0]
  818. # Sixth, we compute the pooled_output, word_sequence_output and entity_sequence_output based on the sequence_output
  819. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  820. if not return_dict:
  821. return (sequence_output, pooled_output) + encoder_outputs[1:]
  822. return BaseLukeModelOutputWithPooling(
  823. last_hidden_state=sequence_output,
  824. pooler_output=pooled_output,
  825. hidden_states=encoder_outputs.hidden_states,
  826. attentions=encoder_outputs.attentions,
  827. entity_last_hidden_state=encoder_outputs.entity_last_hidden_state,
  828. entity_hidden_states=encoder_outputs.entity_hidden_states,
  829. )
  830. def get_extended_attention_mask(
  831. self, word_attention_mask: torch.LongTensor, entity_attention_mask: Optional[torch.LongTensor]
  832. ):
  833. """
  834. Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
  835. Arguments:
  836. word_attention_mask (`torch.LongTensor`):
  837. Attention mask for word tokens with ones indicating tokens to attend to, zeros for tokens to ignore.
  838. entity_attention_mask (`torch.LongTensor`, *optional*):
  839. Attention mask for entity tokens with ones indicating tokens to attend to, zeros for tokens to ignore.
  840. Returns:
  841. `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
  842. """
  843. attention_mask = word_attention_mask
  844. if entity_attention_mask is not None:
  845. attention_mask = torch.cat([attention_mask, entity_attention_mask], dim=-1)
  846. if attention_mask.dim() == 3:
  847. extended_attention_mask = attention_mask[:, None, :, :]
  848. elif attention_mask.dim() == 2:
  849. extended_attention_mask = attention_mask[:, None, None, :]
  850. else:
  851. raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape})")
  852. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
  853. extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
  854. return extended_attention_mask
  855. def create_position_ids_from_input_ids(input_ids, padding_idx):
  856. """
  857. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  858. are ignored. This is modified from fairseq's `utils.make_positions`.
  859. Args:
  860. x: torch.Tensor x:
  861. Returns: torch.Tensor
  862. """
  863. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  864. mask = input_ids.ne(padding_idx).int()
  865. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
  866. return incremental_indices.long() + padding_idx
  867. # Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead
  868. class LukeLMHead(nn.Module):
  869. """Roberta Head for masked language modeling."""
  870. def __init__(self, config):
  871. super().__init__()
  872. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  873. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  874. self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
  875. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  876. self.decoder.bias = self.bias
  877. def forward(self, features, **kwargs):
  878. x = self.dense(features)
  879. x = gelu(x)
  880. x = self.layer_norm(x)
  881. # project back to size of vocabulary with bias
  882. x = self.decoder(x)
  883. return x
  884. def _tie_weights(self):
  885. # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
  886. # For accelerate compatibility and to not break backward compatibility
  887. if self.decoder.bias.device.type == "meta":
  888. self.decoder.bias = self.bias
  889. else:
  890. self.bias = self.decoder.bias
  891. @auto_docstring(
  892. custom_intro="""
  893. The LUKE model with a language modeling head and entity prediction head on top for masked language modeling and
  894. masked entity prediction.
  895. """
  896. )
  897. class LukeForMaskedLM(LukePreTrainedModel):
  898. _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias", "entity_predictions.decoder.weight"]
  899. def __init__(self, config):
  900. super().__init__(config)
  901. self.luke = LukeModel(config)
  902. self.lm_head = LukeLMHead(config)
  903. self.entity_predictions = EntityPredictionHead(config)
  904. self.loss_fn = nn.CrossEntropyLoss()
  905. # Initialize weights and apply final processing
  906. self.post_init()
  907. def tie_weights(self):
  908. super().tie_weights()
  909. self._tie_or_clone_weights(self.entity_predictions.decoder, self.luke.entity_embeddings.entity_embeddings)
  910. def get_output_embeddings(self):
  911. return self.lm_head.decoder
  912. def set_output_embeddings(self, new_embeddings):
  913. self.lm_head.decoder = new_embeddings
  914. @auto_docstring
  915. def forward(
  916. self,
  917. input_ids: Optional[torch.LongTensor] = None,
  918. attention_mask: Optional[torch.FloatTensor] = None,
  919. token_type_ids: Optional[torch.LongTensor] = None,
  920. position_ids: Optional[torch.LongTensor] = None,
  921. entity_ids: Optional[torch.LongTensor] = None,
  922. entity_attention_mask: Optional[torch.LongTensor] = None,
  923. entity_token_type_ids: Optional[torch.LongTensor] = None,
  924. entity_position_ids: Optional[torch.LongTensor] = None,
  925. labels: Optional[torch.LongTensor] = None,
  926. entity_labels: Optional[torch.LongTensor] = None,
  927. head_mask: Optional[torch.FloatTensor] = None,
  928. inputs_embeds: Optional[torch.FloatTensor] = None,
  929. output_attentions: Optional[bool] = None,
  930. output_hidden_states: Optional[bool] = None,
  931. return_dict: Optional[bool] = None,
  932. ) -> Union[tuple, LukeMaskedLMOutput]:
  933. r"""
  934. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  935. Indices of entity tokens in the entity vocabulary.
  936. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  937. [`PreTrainedTokenizer.__call__`] for details.
  938. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  939. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  940. - 1 for entity tokens that are **not masked**,
  941. - 0 for entity tokens that are **masked**.
  942. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  943. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  944. selected in `[0, 1]`:
  945. - 0 corresponds to a *portion A* entity token,
  946. - 1 corresponds to a *portion B* entity token.
  947. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  948. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  949. config.max_position_embeddings - 1]`.
  950. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  951. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  952. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  953. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  954. entity_labels (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  955. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  956. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  957. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  958. """
  959. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  960. outputs = self.luke(
  961. input_ids=input_ids,
  962. attention_mask=attention_mask,
  963. token_type_ids=token_type_ids,
  964. position_ids=position_ids,
  965. entity_ids=entity_ids,
  966. entity_attention_mask=entity_attention_mask,
  967. entity_token_type_ids=entity_token_type_ids,
  968. entity_position_ids=entity_position_ids,
  969. head_mask=head_mask,
  970. inputs_embeds=inputs_embeds,
  971. output_attentions=output_attentions,
  972. output_hidden_states=output_hidden_states,
  973. return_dict=True,
  974. )
  975. loss = None
  976. mlm_loss = None
  977. logits = self.lm_head(outputs.last_hidden_state)
  978. if labels is not None:
  979. # move labels to correct device to enable model parallelism
  980. labels = labels.to(logits.device)
  981. mlm_loss = self.loss_fn(logits.view(-1, self.config.vocab_size), labels.view(-1))
  982. if loss is None:
  983. loss = mlm_loss
  984. mep_loss = None
  985. entity_logits = None
  986. if outputs.entity_last_hidden_state is not None:
  987. entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
  988. if entity_labels is not None:
  989. mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1))
  990. if loss is None:
  991. loss = mep_loss
  992. else:
  993. loss = loss + mep_loss
  994. if not return_dict:
  995. return tuple(
  996. v
  997. for v in [
  998. loss,
  999. mlm_loss,
  1000. mep_loss,
  1001. logits,
  1002. entity_logits,
  1003. outputs.hidden_states,
  1004. outputs.entity_hidden_states,
  1005. outputs.attentions,
  1006. ]
  1007. if v is not None
  1008. )
  1009. return LukeMaskedLMOutput(
  1010. loss=loss,
  1011. mlm_loss=mlm_loss,
  1012. mep_loss=mep_loss,
  1013. logits=logits,
  1014. entity_logits=entity_logits,
  1015. hidden_states=outputs.hidden_states,
  1016. entity_hidden_states=outputs.entity_hidden_states,
  1017. attentions=outputs.attentions,
  1018. )
  1019. @auto_docstring(
  1020. custom_intro="""
  1021. The LUKE model with a classification head on top (a linear layer on top of the hidden state of the first entity
  1022. token) for entity classification tasks, such as Open Entity.
  1023. """
  1024. )
  1025. class LukeForEntityClassification(LukePreTrainedModel):
  1026. def __init__(self, config):
  1027. super().__init__(config)
  1028. self.luke = LukeModel(config)
  1029. self.num_labels = config.num_labels
  1030. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  1031. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1032. # Initialize weights and apply final processing
  1033. self.post_init()
  1034. @auto_docstring
  1035. def forward(
  1036. self,
  1037. input_ids: Optional[torch.LongTensor] = None,
  1038. attention_mask: Optional[torch.FloatTensor] = None,
  1039. token_type_ids: Optional[torch.LongTensor] = None,
  1040. position_ids: Optional[torch.LongTensor] = None,
  1041. entity_ids: Optional[torch.LongTensor] = None,
  1042. entity_attention_mask: Optional[torch.FloatTensor] = None,
  1043. entity_token_type_ids: Optional[torch.LongTensor] = None,
  1044. entity_position_ids: Optional[torch.LongTensor] = None,
  1045. head_mask: Optional[torch.FloatTensor] = None,
  1046. inputs_embeds: Optional[torch.FloatTensor] = None,
  1047. labels: Optional[torch.FloatTensor] = None,
  1048. output_attentions: Optional[bool] = None,
  1049. output_hidden_states: Optional[bool] = None,
  1050. return_dict: Optional[bool] = None,
  1051. ) -> Union[tuple, EntityClassificationOutput]:
  1052. r"""
  1053. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  1054. Indices of entity tokens in the entity vocabulary.
  1055. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1056. [`PreTrainedTokenizer.__call__`] for details.
  1057. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  1058. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  1059. - 1 for entity tokens that are **not masked**,
  1060. - 0 for entity tokens that are **masked**.
  1061. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  1062. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  1063. selected in `[0, 1]`:
  1064. - 0 corresponds to a *portion A* entity token,
  1065. - 1 corresponds to a *portion B* entity token.
  1066. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  1067. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  1068. config.max_position_embeddings - 1]`.
  1069. labels (`torch.LongTensor` of shape `(batch_size,)` or `(batch_size, num_labels)`, *optional*):
  1070. Labels for computing the classification loss. If the shape is `(batch_size,)`, the cross entropy loss is
  1071. used for the single-label classification. In this case, labels should contain the indices that should be in
  1072. `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, num_labels)`, the binary cross entropy
  1073. loss is used for the multi-label classification. In this case, labels should only contain `[0, 1]`, where 0
  1074. and 1 indicate false and true, respectively.
  1075. Examples:
  1076. ```python
  1077. >>> from transformers import AutoTokenizer, LukeForEntityClassification
  1078. >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-open-entity")
  1079. >>> model = LukeForEntityClassification.from_pretrained("studio-ousia/luke-large-finetuned-open-entity")
  1080. >>> text = "Beyoncé lives in Los Angeles."
  1081. >>> entity_spans = [(0, 7)] # character-based entity span corresponding to "Beyoncé"
  1082. >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
  1083. >>> outputs = model(**inputs)
  1084. >>> logits = outputs.logits
  1085. >>> predicted_class_idx = logits.argmax(-1).item()
  1086. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  1087. Predicted class: person
  1088. ```"""
  1089. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1090. outputs = self.luke(
  1091. input_ids=input_ids,
  1092. attention_mask=attention_mask,
  1093. token_type_ids=token_type_ids,
  1094. position_ids=position_ids,
  1095. entity_ids=entity_ids,
  1096. entity_attention_mask=entity_attention_mask,
  1097. entity_token_type_ids=entity_token_type_ids,
  1098. entity_position_ids=entity_position_ids,
  1099. head_mask=head_mask,
  1100. inputs_embeds=inputs_embeds,
  1101. output_attentions=output_attentions,
  1102. output_hidden_states=output_hidden_states,
  1103. return_dict=True,
  1104. )
  1105. feature_vector = outputs.entity_last_hidden_state[:, 0, :]
  1106. feature_vector = self.dropout(feature_vector)
  1107. logits = self.classifier(feature_vector)
  1108. loss = None
  1109. if labels is not None:
  1110. # When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary
  1111. # cross entropy is used otherwise.
  1112. # move labels to correct device to enable model parallelism
  1113. labels = labels.to(logits.device)
  1114. if labels.ndim == 1:
  1115. loss = nn.functional.cross_entropy(logits, labels)
  1116. else:
  1117. loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
  1118. if not return_dict:
  1119. return tuple(
  1120. v
  1121. for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
  1122. if v is not None
  1123. )
  1124. return EntityClassificationOutput(
  1125. loss=loss,
  1126. logits=logits,
  1127. hidden_states=outputs.hidden_states,
  1128. entity_hidden_states=outputs.entity_hidden_states,
  1129. attentions=outputs.attentions,
  1130. )
  1131. @auto_docstring(
  1132. custom_intro="""
  1133. The LUKE model with a classification head on top (a linear layer on top of the hidden states of the two entity
  1134. tokens) for entity pair classification tasks, such as TACRED.
  1135. """
  1136. )
  1137. class LukeForEntityPairClassification(LukePreTrainedModel):
  1138. def __init__(self, config):
  1139. super().__init__(config)
  1140. self.luke = LukeModel(config)
  1141. self.num_labels = config.num_labels
  1142. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  1143. self.classifier = nn.Linear(config.hidden_size * 2, config.num_labels, False)
  1144. # Initialize weights and apply final processing
  1145. self.post_init()
  1146. @auto_docstring
  1147. def forward(
  1148. self,
  1149. input_ids: Optional[torch.LongTensor] = None,
  1150. attention_mask: Optional[torch.FloatTensor] = None,
  1151. token_type_ids: Optional[torch.LongTensor] = None,
  1152. position_ids: Optional[torch.LongTensor] = None,
  1153. entity_ids: Optional[torch.LongTensor] = None,
  1154. entity_attention_mask: Optional[torch.FloatTensor] = None,
  1155. entity_token_type_ids: Optional[torch.LongTensor] = None,
  1156. entity_position_ids: Optional[torch.LongTensor] = None,
  1157. head_mask: Optional[torch.FloatTensor] = None,
  1158. inputs_embeds: Optional[torch.FloatTensor] = None,
  1159. labels: Optional[torch.LongTensor] = None,
  1160. output_attentions: Optional[bool] = None,
  1161. output_hidden_states: Optional[bool] = None,
  1162. return_dict: Optional[bool] = None,
  1163. ) -> Union[tuple, EntityPairClassificationOutput]:
  1164. r"""
  1165. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  1166. Indices of entity tokens in the entity vocabulary.
  1167. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1168. [`PreTrainedTokenizer.__call__`] for details.
  1169. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  1170. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  1171. - 1 for entity tokens that are **not masked**,
  1172. - 0 for entity tokens that are **masked**.
  1173. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  1174. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  1175. selected in `[0, 1]`:
  1176. - 0 corresponds to a *portion A* entity token,
  1177. - 1 corresponds to a *portion B* entity token.
  1178. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  1179. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  1180. config.max_position_embeddings - 1]`.
  1181. labels (`torch.LongTensor` of shape `(batch_size,)` or `(batch_size, num_labels)`, *optional*):
  1182. Labels for computing the classification loss. If the shape is `(batch_size,)`, the cross entropy loss is
  1183. used for the single-label classification. In this case, labels should contain the indices that should be in
  1184. `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, num_labels)`, the binary cross entropy
  1185. loss is used for the multi-label classification. In this case, labels should only contain `[0, 1]`, where 0
  1186. and 1 indicate false and true, respectively.
  1187. Examples:
  1188. ```python
  1189. >>> from transformers import AutoTokenizer, LukeForEntityPairClassification
  1190. >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-tacred")
  1191. >>> model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-large-finetuned-tacred")
  1192. >>> text = "Beyoncé lives in Los Angeles."
  1193. >>> entity_spans = [
  1194. ... (0, 7),
  1195. ... (17, 28),
  1196. ... ] # character-based entity spans corresponding to "Beyoncé" and "Los Angeles"
  1197. >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
  1198. >>> outputs = model(**inputs)
  1199. >>> logits = outputs.logits
  1200. >>> predicted_class_idx = logits.argmax(-1).item()
  1201. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  1202. Predicted class: per:cities_of_residence
  1203. ```"""
  1204. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1205. outputs = self.luke(
  1206. input_ids=input_ids,
  1207. attention_mask=attention_mask,
  1208. token_type_ids=token_type_ids,
  1209. position_ids=position_ids,
  1210. entity_ids=entity_ids,
  1211. entity_attention_mask=entity_attention_mask,
  1212. entity_token_type_ids=entity_token_type_ids,
  1213. entity_position_ids=entity_position_ids,
  1214. head_mask=head_mask,
  1215. inputs_embeds=inputs_embeds,
  1216. output_attentions=output_attentions,
  1217. output_hidden_states=output_hidden_states,
  1218. return_dict=True,
  1219. )
  1220. feature_vector = torch.cat(
  1221. [outputs.entity_last_hidden_state[:, 0, :], outputs.entity_last_hidden_state[:, 1, :]], dim=1
  1222. )
  1223. feature_vector = self.dropout(feature_vector)
  1224. logits = self.classifier(feature_vector)
  1225. loss = None
  1226. if labels is not None:
  1227. # When the number of dimension of `labels` is 1, cross entropy is used as the loss function. The binary
  1228. # cross entropy is used otherwise.
  1229. # move labels to correct device to enable model parallelism
  1230. labels = labels.to(logits.device)
  1231. if labels.ndim == 1:
  1232. loss = nn.functional.cross_entropy(logits, labels)
  1233. else:
  1234. loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
  1235. if not return_dict:
  1236. return tuple(
  1237. v
  1238. for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
  1239. if v is not None
  1240. )
  1241. return EntityPairClassificationOutput(
  1242. loss=loss,
  1243. logits=logits,
  1244. hidden_states=outputs.hidden_states,
  1245. entity_hidden_states=outputs.entity_hidden_states,
  1246. attentions=outputs.attentions,
  1247. )
  1248. @auto_docstring(
  1249. custom_intro="""
  1250. The LUKE model with a span classification head on top (a linear layer on top of the hidden states output) for tasks
  1251. such as named entity recognition.
  1252. """
  1253. )
  1254. class LukeForEntitySpanClassification(LukePreTrainedModel):
  1255. def __init__(self, config):
  1256. super().__init__(config)
  1257. self.luke = LukeModel(config)
  1258. self.num_labels = config.num_labels
  1259. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  1260. self.classifier = nn.Linear(config.hidden_size * 3, config.num_labels)
  1261. # Initialize weights and apply final processing
  1262. self.post_init()
  1263. @auto_docstring
  1264. def forward(
  1265. self,
  1266. input_ids: Optional[torch.LongTensor] = None,
  1267. attention_mask: Optional[torch.FloatTensor] = None,
  1268. token_type_ids: Optional[torch.LongTensor] = None,
  1269. position_ids: Optional[torch.LongTensor] = None,
  1270. entity_ids: Optional[torch.LongTensor] = None,
  1271. entity_attention_mask: Optional[torch.LongTensor] = None,
  1272. entity_token_type_ids: Optional[torch.LongTensor] = None,
  1273. entity_position_ids: Optional[torch.LongTensor] = None,
  1274. entity_start_positions: Optional[torch.LongTensor] = None,
  1275. entity_end_positions: Optional[torch.LongTensor] = None,
  1276. head_mask: Optional[torch.FloatTensor] = None,
  1277. inputs_embeds: Optional[torch.FloatTensor] = None,
  1278. labels: Optional[torch.LongTensor] = None,
  1279. output_attentions: Optional[bool] = None,
  1280. output_hidden_states: Optional[bool] = None,
  1281. return_dict: Optional[bool] = None,
  1282. ) -> Union[tuple, EntitySpanClassificationOutput]:
  1283. r"""
  1284. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  1285. Indices of entity tokens in the entity vocabulary.
  1286. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1287. [`PreTrainedTokenizer.__call__`] for details.
  1288. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  1289. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  1290. - 1 for entity tokens that are **not masked**,
  1291. - 0 for entity tokens that are **masked**.
  1292. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  1293. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  1294. selected in `[0, 1]`:
  1295. - 0 corresponds to a *portion A* entity token,
  1296. - 1 corresponds to a *portion B* entity token.
  1297. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  1298. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  1299. config.max_position_embeddings - 1]`.
  1300. entity_start_positions (`torch.LongTensor`):
  1301. The start positions of entities in the word token sequence.
  1302. entity_end_positions (`torch.LongTensor`):
  1303. The end positions of entities in the word token sequence.
  1304. labels (`torch.LongTensor` of shape `(batch_size, entity_length)` or `(batch_size, entity_length, num_labels)`, *optional*):
  1305. Labels for computing the classification loss. If the shape is `(batch_size, entity_length)`, the cross
  1306. entropy loss is used for the single-label classification. In this case, labels should contain the indices
  1307. that should be in `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, entity_length,
  1308. num_labels)`, the binary cross entropy loss is used for the multi-label classification. In this case,
  1309. labels should only contain `[0, 1]`, where 0 and 1 indicate false and true, respectively.
  1310. Examples:
  1311. ```python
  1312. >>> from transformers import AutoTokenizer, LukeForEntitySpanClassification
  1313. >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")
  1314. >>> model = LukeForEntitySpanClassification.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")
  1315. >>> text = "Beyoncé lives in Los Angeles"
  1316. # List all possible entity spans in the text
  1317. >>> word_start_positions = [0, 8, 14, 17, 21] # character-based start positions of word tokens
  1318. >>> word_end_positions = [7, 13, 16, 20, 28] # character-based end positions of word tokens
  1319. >>> entity_spans = []
  1320. >>> for i, start_pos in enumerate(word_start_positions):
  1321. ... for end_pos in word_end_positions[i:]:
  1322. ... entity_spans.append((start_pos, end_pos))
  1323. >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
  1324. >>> outputs = model(**inputs)
  1325. >>> logits = outputs.logits
  1326. >>> predicted_class_indices = logits.argmax(-1).squeeze().tolist()
  1327. >>> for span, predicted_class_idx in zip(entity_spans, predicted_class_indices):
  1328. ... if predicted_class_idx != 0:
  1329. ... print(text[span[0] : span[1]], model.config.id2label[predicted_class_idx])
  1330. Beyoncé PER
  1331. Los Angeles LOC
  1332. ```"""
  1333. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1334. outputs = self.luke(
  1335. input_ids=input_ids,
  1336. attention_mask=attention_mask,
  1337. token_type_ids=token_type_ids,
  1338. position_ids=position_ids,
  1339. entity_ids=entity_ids,
  1340. entity_attention_mask=entity_attention_mask,
  1341. entity_token_type_ids=entity_token_type_ids,
  1342. entity_position_ids=entity_position_ids,
  1343. head_mask=head_mask,
  1344. inputs_embeds=inputs_embeds,
  1345. output_attentions=output_attentions,
  1346. output_hidden_states=output_hidden_states,
  1347. return_dict=True,
  1348. )
  1349. hidden_size = outputs.last_hidden_state.size(-1)
  1350. entity_start_positions = entity_start_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
  1351. if entity_start_positions.device != outputs.last_hidden_state.device:
  1352. entity_start_positions = entity_start_positions.to(outputs.last_hidden_state.device)
  1353. start_states = torch.gather(outputs.last_hidden_state, -2, entity_start_positions)
  1354. entity_end_positions = entity_end_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
  1355. if entity_end_positions.device != outputs.last_hidden_state.device:
  1356. entity_end_positions = entity_end_positions.to(outputs.last_hidden_state.device)
  1357. end_states = torch.gather(outputs.last_hidden_state, -2, entity_end_positions)
  1358. feature_vector = torch.cat([start_states, end_states, outputs.entity_last_hidden_state], dim=2)
  1359. feature_vector = self.dropout(feature_vector)
  1360. logits = self.classifier(feature_vector)
  1361. loss = None
  1362. if labels is not None:
  1363. # move labels to correct device to enable model parallelism
  1364. labels = labels.to(logits.device)
  1365. # When the number of dimension of `labels` is 2, cross entropy is used as the loss function. The binary
  1366. # cross entropy is used otherwise.
  1367. if labels.ndim == 2:
  1368. loss = nn.functional.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
  1369. else:
  1370. loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
  1371. if not return_dict:
  1372. return tuple(
  1373. v
  1374. for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
  1375. if v is not None
  1376. )
  1377. return EntitySpanClassificationOutput(
  1378. loss=loss,
  1379. logits=logits,
  1380. hidden_states=outputs.hidden_states,
  1381. entity_hidden_states=outputs.entity_hidden_states,
  1382. attentions=outputs.attentions,
  1383. )
  1384. @auto_docstring(
  1385. custom_intro="""
  1386. The LUKE Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  1387. pooled output) e.g. for GLUE tasks.
  1388. """
  1389. )
  1390. class LukeForSequenceClassification(LukePreTrainedModel):
  1391. def __init__(self, config):
  1392. super().__init__(config)
  1393. self.num_labels = config.num_labels
  1394. self.luke = LukeModel(config)
  1395. self.dropout = nn.Dropout(
  1396. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1397. )
  1398. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1399. # Initialize weights and apply final processing
  1400. self.post_init()
  1401. @auto_docstring
  1402. def forward(
  1403. self,
  1404. input_ids: Optional[torch.LongTensor] = None,
  1405. attention_mask: Optional[torch.FloatTensor] = None,
  1406. token_type_ids: Optional[torch.LongTensor] = None,
  1407. position_ids: Optional[torch.LongTensor] = None,
  1408. entity_ids: Optional[torch.LongTensor] = None,
  1409. entity_attention_mask: Optional[torch.FloatTensor] = None,
  1410. entity_token_type_ids: Optional[torch.LongTensor] = None,
  1411. entity_position_ids: Optional[torch.LongTensor] = None,
  1412. head_mask: Optional[torch.FloatTensor] = None,
  1413. inputs_embeds: Optional[torch.FloatTensor] = None,
  1414. labels: Optional[torch.FloatTensor] = None,
  1415. output_attentions: Optional[bool] = None,
  1416. output_hidden_states: Optional[bool] = None,
  1417. return_dict: Optional[bool] = None,
  1418. ) -> Union[tuple, LukeSequenceClassifierOutput]:
  1419. r"""
  1420. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  1421. Indices of entity tokens in the entity vocabulary.
  1422. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1423. [`PreTrainedTokenizer.__call__`] for details.
  1424. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  1425. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  1426. - 1 for entity tokens that are **not masked**,
  1427. - 0 for entity tokens that are **masked**.
  1428. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  1429. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  1430. selected in `[0, 1]`:
  1431. - 0 corresponds to a *portion A* entity token,
  1432. - 1 corresponds to a *portion B* entity token.
  1433. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  1434. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  1435. config.max_position_embeddings - 1]`.
  1436. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1437. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1438. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1439. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1440. """
  1441. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1442. outputs = self.luke(
  1443. input_ids=input_ids,
  1444. attention_mask=attention_mask,
  1445. token_type_ids=token_type_ids,
  1446. position_ids=position_ids,
  1447. entity_ids=entity_ids,
  1448. entity_attention_mask=entity_attention_mask,
  1449. entity_token_type_ids=entity_token_type_ids,
  1450. entity_position_ids=entity_position_ids,
  1451. head_mask=head_mask,
  1452. inputs_embeds=inputs_embeds,
  1453. output_attentions=output_attentions,
  1454. output_hidden_states=output_hidden_states,
  1455. return_dict=True,
  1456. )
  1457. pooled_output = outputs.pooler_output
  1458. pooled_output = self.dropout(pooled_output)
  1459. logits = self.classifier(pooled_output)
  1460. loss = None
  1461. if labels is not None:
  1462. # move labels to correct device to enable model parallelism
  1463. labels = labels.to(logits.device)
  1464. if self.config.problem_type is None:
  1465. if self.num_labels == 1:
  1466. self.config.problem_type = "regression"
  1467. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1468. self.config.problem_type = "single_label_classification"
  1469. else:
  1470. self.config.problem_type = "multi_label_classification"
  1471. if self.config.problem_type == "regression":
  1472. loss_fct = MSELoss()
  1473. if self.num_labels == 1:
  1474. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1475. else:
  1476. loss = loss_fct(logits, labels)
  1477. elif self.config.problem_type == "single_label_classification":
  1478. loss_fct = CrossEntropyLoss()
  1479. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1480. elif self.config.problem_type == "multi_label_classification":
  1481. loss_fct = BCEWithLogitsLoss()
  1482. loss = loss_fct(logits, labels)
  1483. if not return_dict:
  1484. return tuple(
  1485. v
  1486. for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
  1487. if v is not None
  1488. )
  1489. return LukeSequenceClassifierOutput(
  1490. loss=loss,
  1491. logits=logits,
  1492. hidden_states=outputs.hidden_states,
  1493. entity_hidden_states=outputs.entity_hidden_states,
  1494. attentions=outputs.attentions,
  1495. )
  1496. @auto_docstring(
  1497. custom_intro="""
  1498. The LUKE Model with a token classification head on top (a linear layer on top of the hidden-states output). To
  1499. solve Named-Entity Recognition (NER) task using LUKE, `LukeForEntitySpanClassification` is more suitable than this
  1500. class.
  1501. """
  1502. )
  1503. class LukeForTokenClassification(LukePreTrainedModel):
  1504. def __init__(self, config):
  1505. super().__init__(config)
  1506. self.num_labels = config.num_labels
  1507. self.luke = LukeModel(config, add_pooling_layer=False)
  1508. self.dropout = nn.Dropout(
  1509. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1510. )
  1511. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1512. # Initialize weights and apply final processing
  1513. self.post_init()
  1514. @auto_docstring
  1515. def forward(
  1516. self,
  1517. input_ids: Optional[torch.LongTensor] = None,
  1518. attention_mask: Optional[torch.FloatTensor] = None,
  1519. token_type_ids: Optional[torch.LongTensor] = None,
  1520. position_ids: Optional[torch.LongTensor] = None,
  1521. entity_ids: Optional[torch.LongTensor] = None,
  1522. entity_attention_mask: Optional[torch.FloatTensor] = None,
  1523. entity_token_type_ids: Optional[torch.LongTensor] = None,
  1524. entity_position_ids: Optional[torch.LongTensor] = None,
  1525. head_mask: Optional[torch.FloatTensor] = None,
  1526. inputs_embeds: Optional[torch.FloatTensor] = None,
  1527. labels: Optional[torch.FloatTensor] = None,
  1528. output_attentions: Optional[bool] = None,
  1529. output_hidden_states: Optional[bool] = None,
  1530. return_dict: Optional[bool] = None,
  1531. ) -> Union[tuple, LukeTokenClassifierOutput]:
  1532. r"""
  1533. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  1534. Indices of entity tokens in the entity vocabulary.
  1535. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1536. [`PreTrainedTokenizer.__call__`] for details.
  1537. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  1538. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  1539. - 1 for entity tokens that are **not masked**,
  1540. - 0 for entity tokens that are **masked**.
  1541. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  1542. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  1543. selected in `[0, 1]`:
  1544. - 0 corresponds to a *portion A* entity token,
  1545. - 1 corresponds to a *portion B* entity token.
  1546. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  1547. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  1548. config.max_position_embeddings - 1]`.
  1549. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1550. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1551. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  1552. `input_ids` above)
  1553. """
  1554. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1555. outputs = self.luke(
  1556. input_ids=input_ids,
  1557. attention_mask=attention_mask,
  1558. token_type_ids=token_type_ids,
  1559. position_ids=position_ids,
  1560. entity_ids=entity_ids,
  1561. entity_attention_mask=entity_attention_mask,
  1562. entity_token_type_ids=entity_token_type_ids,
  1563. entity_position_ids=entity_position_ids,
  1564. head_mask=head_mask,
  1565. inputs_embeds=inputs_embeds,
  1566. output_attentions=output_attentions,
  1567. output_hidden_states=output_hidden_states,
  1568. return_dict=True,
  1569. )
  1570. sequence_output = outputs.last_hidden_state
  1571. sequence_output = self.dropout(sequence_output)
  1572. logits = self.classifier(sequence_output)
  1573. loss = None
  1574. if labels is not None:
  1575. # move labels to correct device to enable model parallelism
  1576. labels = labels.to(logits.device)
  1577. loss_fct = CrossEntropyLoss()
  1578. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1579. if not return_dict:
  1580. return tuple(
  1581. v
  1582. for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
  1583. if v is not None
  1584. )
  1585. return LukeTokenClassifierOutput(
  1586. loss=loss,
  1587. logits=logits,
  1588. hidden_states=outputs.hidden_states,
  1589. entity_hidden_states=outputs.entity_hidden_states,
  1590. attentions=outputs.attentions,
  1591. )
  1592. @auto_docstring
  1593. class LukeForQuestionAnswering(LukePreTrainedModel):
  1594. def __init__(self, config):
  1595. super().__init__(config)
  1596. self.num_labels = config.num_labels
  1597. self.luke = LukeModel(config, add_pooling_layer=False)
  1598. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1599. # Initialize weights and apply final processing
  1600. self.post_init()
  1601. @auto_docstring
  1602. def forward(
  1603. self,
  1604. input_ids: Optional[torch.LongTensor] = None,
  1605. attention_mask: Optional[torch.FloatTensor] = None,
  1606. token_type_ids: Optional[torch.LongTensor] = None,
  1607. position_ids: Optional[torch.FloatTensor] = None,
  1608. entity_ids: Optional[torch.LongTensor] = None,
  1609. entity_attention_mask: Optional[torch.FloatTensor] = None,
  1610. entity_token_type_ids: Optional[torch.LongTensor] = None,
  1611. entity_position_ids: Optional[torch.LongTensor] = None,
  1612. head_mask: Optional[torch.FloatTensor] = None,
  1613. inputs_embeds: Optional[torch.FloatTensor] = None,
  1614. start_positions: Optional[torch.LongTensor] = None,
  1615. end_positions: Optional[torch.LongTensor] = None,
  1616. output_attentions: Optional[bool] = None,
  1617. output_hidden_states: Optional[bool] = None,
  1618. return_dict: Optional[bool] = None,
  1619. ) -> Union[tuple, LukeQuestionAnsweringModelOutput]:
  1620. r"""
  1621. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  1622. Indices of entity tokens in the entity vocabulary.
  1623. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1624. [`PreTrainedTokenizer.__call__`] for details.
  1625. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  1626. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  1627. - 1 for entity tokens that are **not masked**,
  1628. - 0 for entity tokens that are **masked**.
  1629. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  1630. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  1631. selected in `[0, 1]`:
  1632. - 0 corresponds to a *portion A* entity token,
  1633. - 1 corresponds to a *portion B* entity token.
  1634. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  1635. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  1636. config.max_position_embeddings - 1]`.
  1637. """
  1638. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1639. outputs = self.luke(
  1640. input_ids=input_ids,
  1641. attention_mask=attention_mask,
  1642. token_type_ids=token_type_ids,
  1643. position_ids=position_ids,
  1644. entity_ids=entity_ids,
  1645. entity_attention_mask=entity_attention_mask,
  1646. entity_token_type_ids=entity_token_type_ids,
  1647. entity_position_ids=entity_position_ids,
  1648. head_mask=head_mask,
  1649. inputs_embeds=inputs_embeds,
  1650. output_attentions=output_attentions,
  1651. output_hidden_states=output_hidden_states,
  1652. return_dict=True,
  1653. )
  1654. sequence_output = outputs.last_hidden_state
  1655. logits = self.qa_outputs(sequence_output)
  1656. start_logits, end_logits = logits.split(1, dim=-1)
  1657. start_logits = start_logits.squeeze(-1)
  1658. end_logits = end_logits.squeeze(-1)
  1659. total_loss = None
  1660. if start_positions is not None and end_positions is not None:
  1661. # If we are on multi-GPU, split add a dimension
  1662. if len(start_positions.size()) > 1:
  1663. start_positions = start_positions.squeeze(-1)
  1664. if len(end_positions.size()) > 1:
  1665. end_positions = end_positions.squeeze(-1)
  1666. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1667. ignored_index = start_logits.size(1)
  1668. start_positions.clamp_(0, ignored_index)
  1669. end_positions.clamp_(0, ignored_index)
  1670. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1671. start_loss = loss_fct(start_logits, start_positions)
  1672. end_loss = loss_fct(end_logits, end_positions)
  1673. total_loss = (start_loss + end_loss) / 2
  1674. if not return_dict:
  1675. return tuple(
  1676. v
  1677. for v in [
  1678. total_loss,
  1679. start_logits,
  1680. end_logits,
  1681. outputs.hidden_states,
  1682. outputs.entity_hidden_states,
  1683. outputs.attentions,
  1684. ]
  1685. if v is not None
  1686. )
  1687. return LukeQuestionAnsweringModelOutput(
  1688. loss=total_loss,
  1689. start_logits=start_logits,
  1690. end_logits=end_logits,
  1691. hidden_states=outputs.hidden_states,
  1692. entity_hidden_states=outputs.entity_hidden_states,
  1693. attentions=outputs.attentions,
  1694. )
  1695. @auto_docstring
  1696. class LukeForMultipleChoice(LukePreTrainedModel):
  1697. def __init__(self, config):
  1698. super().__init__(config)
  1699. self.luke = LukeModel(config)
  1700. self.dropout = nn.Dropout(
  1701. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1702. )
  1703. self.classifier = nn.Linear(config.hidden_size, 1)
  1704. # Initialize weights and apply final processing
  1705. self.post_init()
  1706. @auto_docstring
  1707. def forward(
  1708. self,
  1709. input_ids: Optional[torch.LongTensor] = None,
  1710. attention_mask: Optional[torch.FloatTensor] = None,
  1711. token_type_ids: Optional[torch.LongTensor] = None,
  1712. position_ids: Optional[torch.LongTensor] = None,
  1713. entity_ids: Optional[torch.LongTensor] = None,
  1714. entity_attention_mask: Optional[torch.FloatTensor] = None,
  1715. entity_token_type_ids: Optional[torch.LongTensor] = None,
  1716. entity_position_ids: Optional[torch.LongTensor] = None,
  1717. head_mask: Optional[torch.FloatTensor] = None,
  1718. inputs_embeds: Optional[torch.FloatTensor] = None,
  1719. labels: Optional[torch.FloatTensor] = None,
  1720. output_attentions: Optional[bool] = None,
  1721. output_hidden_states: Optional[bool] = None,
  1722. return_dict: Optional[bool] = None,
  1723. ) -> Union[tuple, LukeMultipleChoiceModelOutput]:
  1724. r"""
  1725. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  1726. Indices of input sequence tokens in the vocabulary.
  1727. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1728. [`PreTrainedTokenizer.__call__`] for details.
  1729. [What are input IDs?](../glossary#input-ids)
  1730. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1731. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1732. 1]`:
  1733. - 0 corresponds to a *sentence A* token,
  1734. - 1 corresponds to a *sentence B* token.
  1735. [What are token type IDs?](../glossary#token-type-ids)
  1736. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1737. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1738. config.max_position_embeddings - 1]`.
  1739. [What are position IDs?](../glossary#position-ids)
  1740. entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
  1741. Indices of entity tokens in the entity vocabulary.
  1742. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1743. [`PreTrainedTokenizer.__call__`] for details.
  1744. entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
  1745. Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:
  1746. - 1 for entity tokens that are **not masked**,
  1747. - 0 for entity tokens that are **masked**.
  1748. entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
  1749. Segment token indices to indicate first and second portions of the entity token inputs. Indices are
  1750. selected in `[0, 1]`:
  1751. - 0 corresponds to a *portion A* entity token,
  1752. - 1 corresponds to a *portion B* entity token.
  1753. entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
  1754. Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
  1755. config.max_position_embeddings - 1]`.
  1756. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  1757. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1758. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1759. model's internal embedding lookup matrix.
  1760. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1761. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1762. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  1763. `input_ids` above)
  1764. """
  1765. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1766. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1767. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1768. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1769. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1770. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1771. inputs_embeds = (
  1772. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1773. if inputs_embeds is not None
  1774. else None
  1775. )
  1776. entity_ids = entity_ids.view(-1, entity_ids.size(-1)) if entity_ids is not None else None
  1777. entity_attention_mask = (
  1778. entity_attention_mask.view(-1, entity_attention_mask.size(-1))
  1779. if entity_attention_mask is not None
  1780. else None
  1781. )
  1782. entity_token_type_ids = (
  1783. entity_token_type_ids.view(-1, entity_token_type_ids.size(-1))
  1784. if entity_token_type_ids is not None
  1785. else None
  1786. )
  1787. entity_position_ids = (
  1788. entity_position_ids.view(-1, entity_position_ids.size(-2), entity_position_ids.size(-1))
  1789. if entity_position_ids is not None
  1790. else None
  1791. )
  1792. outputs = self.luke(
  1793. input_ids=input_ids,
  1794. attention_mask=attention_mask,
  1795. token_type_ids=token_type_ids,
  1796. position_ids=position_ids,
  1797. entity_ids=entity_ids,
  1798. entity_attention_mask=entity_attention_mask,
  1799. entity_token_type_ids=entity_token_type_ids,
  1800. entity_position_ids=entity_position_ids,
  1801. head_mask=head_mask,
  1802. inputs_embeds=inputs_embeds,
  1803. output_attentions=output_attentions,
  1804. output_hidden_states=output_hidden_states,
  1805. return_dict=True,
  1806. )
  1807. pooled_output = outputs.pooler_output
  1808. pooled_output = self.dropout(pooled_output)
  1809. logits = self.classifier(pooled_output)
  1810. reshaped_logits = logits.view(-1, num_choices)
  1811. loss = None
  1812. if labels is not None:
  1813. # move labels to correct device to enable model parallelism
  1814. labels = labels.to(reshaped_logits.device)
  1815. loss_fct = CrossEntropyLoss()
  1816. loss = loss_fct(reshaped_logits, labels)
  1817. if not return_dict:
  1818. return tuple(
  1819. v
  1820. for v in [
  1821. loss,
  1822. reshaped_logits,
  1823. outputs.hidden_states,
  1824. outputs.entity_hidden_states,
  1825. outputs.attentions,
  1826. ]
  1827. if v is not None
  1828. )
  1829. return LukeMultipleChoiceModelOutput(
  1830. loss=loss,
  1831. logits=reshaped_logits,
  1832. hidden_states=outputs.hidden_states,
  1833. entity_hidden_states=outputs.entity_hidden_states,
  1834. attentions=outputs.attentions,
  1835. )
  1836. __all__ = [
  1837. "LukeForEntityClassification",
  1838. "LukeForEntityPairClassification",
  1839. "LukeForEntitySpanClassification",
  1840. "LukeForMultipleChoice",
  1841. "LukeForQuestionAnswering",
  1842. "LukeForSequenceClassification",
  1843. "LukeForTokenClassification",
  1844. "LukeForMaskedLM",
  1845. "LukeModel",
  1846. "LukePreTrainedModel",
  1847. ]