modeling_xlnet.py 105 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254225522562257225822592260226122622263226422652266226722682269227022712272227322742275227622772278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389
  1. # coding=utf-8
  2. # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """
  17. PyTorch XLNet model.
  18. """
  19. import warnings
  20. from dataclasses import dataclass
  21. from typing import Callable, Optional, Union
  22. import torch
  23. from torch import nn
  24. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  25. from ...activations import ACT2FN, get_activation
  26. from ...generation import GenerationMixin
  27. from ...modeling_utils import PreTrainedModel
  28. from ...pytorch_utils import apply_chunking_to_forward
  29. from ...utils import ModelOutput, auto_docstring, logging
  30. from .configuration_xlnet import XLNetConfig
  31. logger = logging.get_logger(__name__)
  32. def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
  33. """
  34. A map of modules from TF to PyTorch. I use a map to keep the PyTorch model as identical to the original PyTorch
  35. model as possible.
  36. """
  37. tf_to_pt_map = {}
  38. if hasattr(model, "transformer"):
  39. if hasattr(model, "lm_loss"):
  40. # We will load also the output bias
  41. tf_to_pt_map["model/lm_loss/bias"] = model.lm_loss.bias
  42. if hasattr(model, "sequence_summary") and "model/sequnece_summary/summary/kernel" in tf_weights:
  43. # We will load also the sequence summary
  44. tf_to_pt_map["model/sequnece_summary/summary/kernel"] = model.sequence_summary.summary.weight
  45. tf_to_pt_map["model/sequnece_summary/summary/bias"] = model.sequence_summary.summary.bias
  46. if (
  47. hasattr(model, "logits_proj")
  48. and config.finetuning_task is not None
  49. and f"model/regression_{config.finetuning_task}/logit/kernel" in tf_weights
  50. ):
  51. tf_to_pt_map[f"model/regression_{config.finetuning_task}/logit/kernel"] = model.logits_proj.weight
  52. tf_to_pt_map[f"model/regression_{config.finetuning_task}/logit/bias"] = model.logits_proj.bias
  53. # Now load the rest of the transformer
  54. model = model.transformer
  55. # Embeddings and output
  56. tf_to_pt_map.update(
  57. {
  58. "model/transformer/word_embedding/lookup_table": model.word_embedding.weight,
  59. "model/transformer/mask_emb/mask_emb": model.mask_emb,
  60. }
  61. )
  62. # Transformer blocks
  63. for i, b in enumerate(model.layer):
  64. layer_str = f"model/transformer/layer_{i}/"
  65. tf_to_pt_map.update(
  66. {
  67. layer_str + "rel_attn/LayerNorm/gamma": b.rel_attn.layer_norm.weight,
  68. layer_str + "rel_attn/LayerNorm/beta": b.rel_attn.layer_norm.bias,
  69. layer_str + "rel_attn/o/kernel": b.rel_attn.o,
  70. layer_str + "rel_attn/q/kernel": b.rel_attn.q,
  71. layer_str + "rel_attn/k/kernel": b.rel_attn.k,
  72. layer_str + "rel_attn/r/kernel": b.rel_attn.r,
  73. layer_str + "rel_attn/v/kernel": b.rel_attn.v,
  74. layer_str + "ff/LayerNorm/gamma": b.ff.layer_norm.weight,
  75. layer_str + "ff/LayerNorm/beta": b.ff.layer_norm.bias,
  76. layer_str + "ff/layer_1/kernel": b.ff.layer_1.weight,
  77. layer_str + "ff/layer_1/bias": b.ff.layer_1.bias,
  78. layer_str + "ff/layer_2/kernel": b.ff.layer_2.weight,
  79. layer_str + "ff/layer_2/bias": b.ff.layer_2.bias,
  80. }
  81. )
  82. # Relative positioning biases
  83. if config.untie_r:
  84. r_r_list = []
  85. r_w_list = []
  86. r_s_list = []
  87. seg_embed_list = []
  88. for b in model.layer:
  89. r_r_list.append(b.rel_attn.r_r_bias)
  90. r_w_list.append(b.rel_attn.r_w_bias)
  91. r_s_list.append(b.rel_attn.r_s_bias)
  92. seg_embed_list.append(b.rel_attn.seg_embed)
  93. else:
  94. r_r_list = [model.r_r_bias]
  95. r_w_list = [model.r_w_bias]
  96. r_s_list = [model.r_s_bias]
  97. seg_embed_list = [model.seg_embed]
  98. tf_to_pt_map.update(
  99. {
  100. "model/transformer/r_r_bias": r_r_list,
  101. "model/transformer/r_w_bias": r_w_list,
  102. "model/transformer/r_s_bias": r_s_list,
  103. "model/transformer/seg_embed": seg_embed_list,
  104. }
  105. )
  106. return tf_to_pt_map
  107. def load_tf_weights_in_xlnet(model, config, tf_path):
  108. """Load tf checkpoints in a pytorch model"""
  109. try:
  110. import numpy as np
  111. import tensorflow as tf
  112. except ImportError:
  113. logger.error(
  114. "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
  115. "https://www.tensorflow.org/install/ for installation instructions."
  116. )
  117. raise
  118. # Load weights from TF model
  119. init_vars = tf.train.list_variables(tf_path)
  120. tf_weights = {}
  121. for name, shape in init_vars:
  122. logger.info(f"Loading TF weight {name} with shape {shape}")
  123. array = tf.train.load_variable(tf_path, name)
  124. tf_weights[name] = array
  125. # Build TF to PyTorch weights loading map
  126. tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights)
  127. for name, pointer in tf_to_pt_map.items():
  128. logger.info(f"Importing {name}")
  129. if name not in tf_weights:
  130. logger.info(f"{name} not in tf pre-trained weights, skipping")
  131. continue
  132. array = tf_weights[name]
  133. # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
  134. # which are not required for using pretrained model
  135. if "kernel" in name and ("ff" in name or "summary" in name or "logit" in name):
  136. logger.info("Transposing")
  137. array = np.transpose(array)
  138. if isinstance(pointer, list):
  139. # Here we will split the TF weights
  140. assert len(pointer) == array.shape[0], (
  141. f"Pointer length {len(pointer)} and array length {array.shape[0]} mismatched"
  142. )
  143. for i, p_i in enumerate(pointer):
  144. arr_i = array[i, ...]
  145. try:
  146. assert p_i.shape == arr_i.shape, (
  147. f"Pointer shape {p_i.shape} and array shape {arr_i.shape} mismatched"
  148. )
  149. except AssertionError as e:
  150. e.args += (p_i.shape, arr_i.shape)
  151. raise
  152. logger.info(f"Initialize PyTorch weight {name} for layer {i}")
  153. p_i.data = torch.from_numpy(arr_i)
  154. else:
  155. try:
  156. assert pointer.shape == array.shape, (
  157. f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
  158. )
  159. except AssertionError as e:
  160. e.args += (pointer.shape, array.shape)
  161. raise
  162. logger.info(f"Initialize PyTorch weight {name}")
  163. pointer.data = torch.from_numpy(array)
  164. tf_weights.pop(name, None)
  165. tf_weights.pop(name + "/Adam", None)
  166. tf_weights.pop(name + "/Adam_1", None)
  167. logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}")
  168. return model
  169. class XLNetRelativeAttention(nn.Module):
  170. def __init__(self, config):
  171. super().__init__()
  172. if config.d_model % config.n_head != 0:
  173. raise ValueError(
  174. f"The hidden size ({config.d_model}) is not a multiple of the number of attention "
  175. f"heads ({config.n_head}"
  176. )
  177. self.n_head = config.n_head
  178. self.d_head = config.d_head
  179. self.d_model = config.d_model
  180. self.scale = 1 / (config.d_head**0.5)
  181. self.q = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
  182. self.k = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
  183. self.v = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
  184. self.o = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
  185. self.r = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
  186. self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
  187. self.r_s_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
  188. self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
  189. self.seg_embed = nn.Parameter(torch.FloatTensor(2, self.n_head, self.d_head))
  190. self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
  191. self.dropout = nn.Dropout(config.dropout)
  192. def prune_heads(self, heads):
  193. raise NotImplementedError
  194. @staticmethod
  195. def rel_shift(x, klen=-1):
  196. """perform relative shift to form the relative attention score."""
  197. x_size = x.shape
  198. x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
  199. x = x[1:, ...]
  200. x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
  201. # x = x[:, 0:klen, :, :]
  202. x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))
  203. return x
  204. @staticmethod
  205. def rel_shift_bnij(x, klen=-1):
  206. x_size = x.shape
  207. x = x.reshape(x_size[0], x_size[1], x_size[3], x_size[2])
  208. x = x[:, :, 1:, :]
  209. x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3] - 1)
  210. # Note: the tensor-slice form was faster in my testing than torch.index_select
  211. # However, tracing doesn't like the nature of the slice, and if klen changes
  212. # during the run then it'll fail, whereas index_select will be fine.
  213. x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
  214. # x = x[:, :, :, :klen]
  215. return x
  216. def rel_attn_core(
  217. self,
  218. q_head,
  219. k_head_h,
  220. v_head_h,
  221. k_head_r,
  222. seg_mat=None,
  223. attn_mask=None,
  224. head_mask=None,
  225. output_attentions=False,
  226. ):
  227. """Core relative positional attention operations."""
  228. # content based attention score
  229. ac = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_w_bias, k_head_h)
  230. # position based attention score
  231. bd = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_r_bias, k_head_r)
  232. bd = self.rel_shift_bnij(bd, klen=ac.shape[3])
  233. # segment based attention score
  234. if seg_mat is None:
  235. ef = 0
  236. else:
  237. ef = torch.einsum("ibnd,snd->ibns", q_head + self.r_s_bias, self.seg_embed)
  238. ef = torch.einsum("ijbs,ibns->bnij", seg_mat, ef)
  239. # merge attention scores and perform masking
  240. attn_score = (ac + bd + ef) * self.scale
  241. if attn_mask is not None:
  242. # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
  243. if attn_mask.dtype == torch.float16:
  244. attn_score = attn_score - 65500 * torch.einsum("ijbn->bnij", attn_mask)
  245. else:
  246. attn_score = attn_score - 1e30 * torch.einsum("ijbn->bnij", attn_mask)
  247. # attention probability
  248. attn_prob = nn.functional.softmax(attn_score, dim=3)
  249. attn_prob = self.dropout(attn_prob)
  250. # Mask heads if we want to
  251. if head_mask is not None:
  252. attn_prob = attn_prob * torch.einsum("ijbn->bnij", head_mask)
  253. # attention output
  254. attn_vec = torch.einsum("bnij,jbnd->ibnd", attn_prob, v_head_h)
  255. if output_attentions:
  256. return attn_vec, torch.einsum("bnij->ijbn", attn_prob)
  257. return attn_vec
  258. def post_attention(self, h, attn_vec, residual=True):
  259. """Post-attention processing."""
  260. # post-attention projection (back to `d_model`)
  261. attn_out = torch.einsum("ibnd,hnd->ibh", attn_vec, self.o)
  262. attn_out = self.dropout(attn_out)
  263. if residual:
  264. attn_out = attn_out + h
  265. output = self.layer_norm(attn_out)
  266. return output
  267. def forward(
  268. self,
  269. h,
  270. g,
  271. attn_mask_h,
  272. attn_mask_g,
  273. r,
  274. seg_mat,
  275. mems=None,
  276. target_mapping=None,
  277. head_mask=None,
  278. output_attentions=False,
  279. ):
  280. if g is not None:
  281. # Two-stream attention with relative positional encoding.
  282. # content based attention score
  283. if mems is not None and mems.dim() > 1:
  284. cat = torch.cat([mems, h], dim=0)
  285. else:
  286. cat = h
  287. # content-based key head
  288. k_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.k)
  289. # content-based value head
  290. v_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.v)
  291. # position-based key head
  292. k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)
  293. # h-stream
  294. # content-stream query head
  295. q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
  296. # core attention ops
  297. attn_vec_h = self.rel_attn_core(
  298. q_head_h,
  299. k_head_h,
  300. v_head_h,
  301. k_head_r,
  302. seg_mat=seg_mat,
  303. attn_mask=attn_mask_h,
  304. head_mask=head_mask,
  305. output_attentions=output_attentions,
  306. )
  307. if output_attentions:
  308. attn_vec_h, attn_prob_h = attn_vec_h
  309. # post processing
  310. output_h = self.post_attention(h, attn_vec_h)
  311. # g-stream
  312. # query-stream query head
  313. q_head_g = torch.einsum("ibh,hnd->ibnd", g, self.q)
  314. # core attention ops
  315. if target_mapping is not None:
  316. q_head_g = torch.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
  317. attn_vec_g = self.rel_attn_core(
  318. q_head_g,
  319. k_head_h,
  320. v_head_h,
  321. k_head_r,
  322. seg_mat=seg_mat,
  323. attn_mask=attn_mask_g,
  324. head_mask=head_mask,
  325. output_attentions=output_attentions,
  326. )
  327. if output_attentions:
  328. attn_vec_g, attn_prob_g = attn_vec_g
  329. attn_vec_g = torch.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
  330. else:
  331. attn_vec_g = self.rel_attn_core(
  332. q_head_g,
  333. k_head_h,
  334. v_head_h,
  335. k_head_r,
  336. seg_mat=seg_mat,
  337. attn_mask=attn_mask_g,
  338. head_mask=head_mask,
  339. output_attentions=output_attentions,
  340. )
  341. if output_attentions:
  342. attn_vec_g, attn_prob_g = attn_vec_g
  343. # post processing
  344. output_g = self.post_attention(g, attn_vec_g)
  345. if output_attentions:
  346. attn_prob = attn_prob_h, attn_prob_g
  347. else:
  348. # Multi-head attention with relative positional encoding
  349. if mems is not None and mems.dim() > 1:
  350. cat = torch.cat([mems, h], dim=0)
  351. else:
  352. cat = h
  353. # content heads
  354. q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
  355. k_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.k)
  356. v_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.v)
  357. # positional heads
  358. # type casting for fp16 support
  359. k_head_r = torch.einsum("ibh,hnd->ibnd", r.type(self.r.dtype), self.r)
  360. # core attention ops
  361. attn_vec = self.rel_attn_core(
  362. q_head_h,
  363. k_head_h,
  364. v_head_h,
  365. k_head_r,
  366. seg_mat=seg_mat,
  367. attn_mask=attn_mask_h,
  368. head_mask=head_mask,
  369. output_attentions=output_attentions,
  370. )
  371. if output_attentions:
  372. attn_vec, attn_prob = attn_vec
  373. # post processing
  374. output_h = self.post_attention(h, attn_vec)
  375. output_g = None
  376. outputs = (output_h, output_g)
  377. if output_attentions:
  378. outputs = outputs + (attn_prob,)
  379. return outputs
  380. class XLNetFeedForward(nn.Module):
  381. def __init__(self, config):
  382. super().__init__()
  383. self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
  384. self.layer_1 = nn.Linear(config.d_model, config.d_inner)
  385. self.layer_2 = nn.Linear(config.d_inner, config.d_model)
  386. self.dropout = nn.Dropout(config.dropout)
  387. if isinstance(config.ff_activation, str):
  388. self.activation_function = ACT2FN[config.ff_activation]
  389. else:
  390. self.activation_function = config.ff_activation
  391. def forward(self, inp):
  392. output = inp
  393. output = self.layer_1(output)
  394. output = self.activation_function(output)
  395. output = self.dropout(output)
  396. output = self.layer_2(output)
  397. output = self.dropout(output)
  398. output = self.layer_norm(output + inp)
  399. return output
  400. class XLNetLayer(nn.Module):
  401. def __init__(self, config):
  402. super().__init__()
  403. self.rel_attn = XLNetRelativeAttention(config)
  404. self.ff = XLNetFeedForward(config)
  405. self.dropout = nn.Dropout(config.dropout)
  406. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  407. self.seq_len_dim = 1
  408. def forward(
  409. self,
  410. output_h,
  411. output_g,
  412. attn_mask_h,
  413. attn_mask_g,
  414. r,
  415. seg_mat,
  416. mems=None,
  417. target_mapping=None,
  418. head_mask=None,
  419. output_attentions=False,
  420. ):
  421. outputs = self.rel_attn(
  422. output_h,
  423. output_g,
  424. attn_mask_h,
  425. attn_mask_g,
  426. r,
  427. seg_mat,
  428. mems=mems,
  429. target_mapping=target_mapping,
  430. head_mask=head_mask,
  431. output_attentions=output_attentions,
  432. )
  433. output_h, output_g = outputs[:2]
  434. if output_g is not None:
  435. output_g = apply_chunking_to_forward(
  436. self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, output_g
  437. )
  438. output_h = apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, output_h)
  439. outputs = (output_h, output_g) + outputs[2:] # Add again attentions if there are there
  440. return outputs
  441. def ff_chunk(self, output_x):
  442. output_x = self.ff(output_x)
  443. return output_x
  444. # Copied from transformers.models.xlm.modeling_xlm.XLMPoolerStartLogits with XLM->XLNet
  445. class XLNetPoolerStartLogits(nn.Module):
  446. """
  447. Compute SQuAD start logits from sequence hidden states.
  448. Args:
  449. config ([`XLNetConfig`]):
  450. The config used by the model, will be used to grab the `hidden_size` of the model.
  451. """
  452. def __init__(self, config: XLNetConfig):
  453. super().__init__()
  454. self.dense = nn.Linear(config.hidden_size, 1)
  455. def forward(
  456. self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
  457. ) -> torch.FloatTensor:
  458. """
  459. Args:
  460. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  461. The final hidden states of the model.
  462. p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
  463. Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
  464. should be masked.
  465. Returns:
  466. `torch.FloatTensor`: The start logits for SQuAD.
  467. """
  468. x = self.dense(hidden_states).squeeze(-1)
  469. if p_mask is not None:
  470. if p_mask.dtype == torch.float16:
  471. x = x * (1 - p_mask) - 65500 * p_mask
  472. else:
  473. x = x * (1 - p_mask) - 1e30 * p_mask
  474. return x
  475. # Copied from transformers.models.xlm.modeling_xlm.XLMPoolerEndLogits with XLM->XLNet
  476. class XLNetPoolerEndLogits(nn.Module):
  477. """
  478. Compute SQuAD end logits from sequence hidden states.
  479. Args:
  480. config ([`XLNetConfig`]):
  481. The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
  482. to use.
  483. """
  484. def __init__(self, config: XLNetConfig):
  485. super().__init__()
  486. self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
  487. self.activation = nn.Tanh()
  488. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  489. self.dense_1 = nn.Linear(config.hidden_size, 1)
  490. def forward(
  491. self,
  492. hidden_states: torch.FloatTensor,
  493. start_states: Optional[torch.FloatTensor] = None,
  494. start_positions: Optional[torch.LongTensor] = None,
  495. p_mask: Optional[torch.FloatTensor] = None,
  496. ) -> torch.FloatTensor:
  497. """
  498. Args:
  499. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  500. The final hidden states of the model.
  501. start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
  502. The hidden states of the first tokens for the labeled span.
  503. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  504. The position of the first token for the labeled span.
  505. p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
  506. Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
  507. should be masked.
  508. <Tip>
  509. One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
  510. `start_states`.
  511. </Tip>
  512. Returns:
  513. `torch.FloatTensor`: The end logits for SQuAD.
  514. """
  515. assert start_states is not None or start_positions is not None, (
  516. "One of start_states, start_positions should be not None"
  517. )
  518. if start_positions is not None:
  519. slen, hsz = hidden_states.shape[-2:]
  520. start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
  521. start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
  522. start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
  523. x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
  524. x = self.activation(x)
  525. x = self.LayerNorm(x)
  526. x = self.dense_1(x).squeeze(-1)
  527. if p_mask is not None:
  528. if p_mask.dtype == torch.float16:
  529. x = x * (1 - p_mask) - 65500 * p_mask
  530. else:
  531. x = x * (1 - p_mask) - 1e30 * p_mask
  532. return x
  533. # Copied from transformers.models.xlm.modeling_xlm.XLMPoolerAnswerClass with XLM->XLNet
  534. class XLNetPoolerAnswerClass(nn.Module):
  535. """
  536. Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
  537. Args:
  538. config ([`XLNetConfig`]):
  539. The config used by the model, will be used to grab the `hidden_size` of the model.
  540. """
  541. def __init__(self, config: XLNetConfig):
  542. super().__init__()
  543. self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
  544. self.activation = nn.Tanh()
  545. self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
  546. def forward(
  547. self,
  548. hidden_states: torch.FloatTensor,
  549. start_states: Optional[torch.FloatTensor] = None,
  550. start_positions: Optional[torch.LongTensor] = None,
  551. cls_index: Optional[torch.LongTensor] = None,
  552. ) -> torch.FloatTensor:
  553. """
  554. Args:
  555. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  556. The final hidden states of the model.
  557. start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
  558. The hidden states of the first tokens for the labeled span.
  559. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  560. The position of the first token for the labeled span.
  561. cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  562. Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
  563. <Tip>
  564. One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
  565. `start_states`.
  566. </Tip>
  567. Returns:
  568. `torch.FloatTensor`: The SQuAD 2.0 answer class.
  569. """
  570. # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
  571. hsz = hidden_states.shape[-1]
  572. assert start_states is not None or start_positions is not None, (
  573. "One of start_states, start_positions should be not None"
  574. )
  575. if start_positions is not None:
  576. start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
  577. start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
  578. if cls_index is not None:
  579. cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
  580. cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
  581. else:
  582. cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
  583. x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
  584. x = self.activation(x)
  585. x = self.dense_1(x).squeeze(-1)
  586. return x
  587. # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->XLNet
  588. class XLNetSequenceSummary(nn.Module):
  589. r"""
  590. Compute a single vector summary of a sequence hidden states.
  591. Args:
  592. config ([`XLNetConfig`]):
  593. The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
  594. config class of your model for the default values it uses):
  595. - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
  596. - `"last"` -- Take the last token hidden state (like XLNet)
  597. - `"first"` -- Take the first token hidden state (like Bert)
  598. - `"mean"` -- Take the mean of all tokens hidden states
  599. - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
  600. - `"attn"` -- Not implemented now, use multi-head attention
  601. - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
  602. - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
  603. (otherwise to `config.hidden_size`).
  604. - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
  605. another string or `None` will add no activation.
  606. - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
  607. - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
  608. """
  609. def __init__(self, config: XLNetConfig):
  610. super().__init__()
  611. self.summary_type = getattr(config, "summary_type", "last")
  612. if self.summary_type == "attn":
  613. # We should use a standard multi-head attention module with absolute positional embedding for that.
  614. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
  615. # We can probably just use the multi-head attention module of PyTorch >=1.1.0
  616. raise NotImplementedError
  617. self.summary = nn.Identity()
  618. if hasattr(config, "summary_use_proj") and config.summary_use_proj:
  619. if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
  620. num_classes = config.num_labels
  621. else:
  622. num_classes = config.hidden_size
  623. self.summary = nn.Linear(config.hidden_size, num_classes)
  624. activation_string = getattr(config, "summary_activation", None)
  625. self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
  626. self.first_dropout = nn.Identity()
  627. if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
  628. self.first_dropout = nn.Dropout(config.summary_first_dropout)
  629. self.last_dropout = nn.Identity()
  630. if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
  631. self.last_dropout = nn.Dropout(config.summary_last_dropout)
  632. def forward(
  633. self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
  634. ) -> torch.FloatTensor:
  635. """
  636. Compute a single vector summary of a sequence hidden states.
  637. Args:
  638. hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
  639. The hidden states of the last layer.
  640. cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
  641. Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
  642. Returns:
  643. `torch.FloatTensor`: The summary of the sequence hidden states.
  644. """
  645. if self.summary_type == "last":
  646. output = hidden_states[:, -1]
  647. elif self.summary_type == "first":
  648. output = hidden_states[:, 0]
  649. elif self.summary_type == "mean":
  650. output = hidden_states.mean(dim=1)
  651. elif self.summary_type == "cls_index":
  652. if cls_index is None:
  653. cls_index = torch.full_like(
  654. hidden_states[..., :1, :],
  655. hidden_states.shape[-2] - 1,
  656. dtype=torch.long,
  657. )
  658. else:
  659. cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
  660. cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
  661. # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
  662. output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
  663. elif self.summary_type == "attn":
  664. raise NotImplementedError
  665. output = self.first_dropout(output)
  666. output = self.summary(output)
  667. output = self.activation(output)
  668. output = self.last_dropout(output)
  669. return output
  670. @auto_docstring
  671. class XLNetPreTrainedModel(PreTrainedModel):
  672. config: XLNetConfig
  673. load_tf_weights = load_tf_weights_in_xlnet
  674. base_model_prefix = "transformer"
  675. def _init_weights(self, module):
  676. """Initialize the weights."""
  677. if isinstance(module, nn.Linear):
  678. # Slightly different from the TF version which uses truncated_normal for initialization
  679. # cf https://github.com/pytorch/pytorch/pull/5617
  680. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  681. if module.bias is not None:
  682. module.bias.data.zero_()
  683. elif isinstance(module, nn.Embedding):
  684. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  685. if module.padding_idx is not None:
  686. module.weight.data[module.padding_idx].zero_()
  687. elif isinstance(module, nn.LayerNorm):
  688. module.bias.data.zero_()
  689. module.weight.data.fill_(1.0)
  690. elif isinstance(module, XLNetRelativeAttention):
  691. for param in [
  692. module.q,
  693. module.k,
  694. module.v,
  695. module.o,
  696. module.r,
  697. module.r_r_bias,
  698. module.r_s_bias,
  699. module.r_w_bias,
  700. module.seg_embed,
  701. ]:
  702. param.data.normal_(mean=0.0, std=self.config.initializer_range)
  703. elif isinstance(module, XLNetModel):
  704. module.mask_emb.data.normal_(mean=0.0, std=self.config.initializer_range)
  705. @dataclass
  706. @auto_docstring(
  707. custom_intro="""
  708. Output type of [`XLNetModel`].
  709. """
  710. )
  711. class XLNetModelOutput(ModelOutput):
  712. r"""
  713. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_predict, hidden_size)`):
  714. Sequence of hidden-states at the last layer of the model.
  715. `num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict`
  716. corresponds to `sequence_length`.
  717. mems (`list[torch.FloatTensor]` of length `config.n_layers`):
  718. Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The
  719. token ids which have their past given to this model should not be passed as `input_ids` as they have
  720. already been computed.
  721. """
  722. last_hidden_state: torch.FloatTensor
  723. mems: Optional[list[torch.FloatTensor]] = None
  724. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  725. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  726. @dataclass
  727. @auto_docstring(
  728. custom_intro="""
  729. Output type of [`XLNetLMHeadModel`].
  730. """
  731. )
  732. class XLNetLMHeadModelOutput(ModelOutput):
  733. r"""
  734. loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided):
  735. Language modeling loss (for next-token prediction).
  736. logits (`torch.FloatTensor` of shape `(batch_size, num_predict, config.vocab_size)`):
  737. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  738. `num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict`
  739. corresponds to `sequence_length`.
  740. mems (`list[torch.FloatTensor]` of length `config.n_layers`):
  741. Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The
  742. token ids which have their past given to this model should not be passed as `input_ids` as they have
  743. already been computed.
  744. """
  745. loss: Optional[torch.FloatTensor] = None
  746. logits: Optional[torch.FloatTensor] = None
  747. mems: Optional[list[torch.FloatTensor]] = None
  748. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  749. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  750. @dataclass
  751. @auto_docstring(
  752. custom_intro="""
  753. Output type of [`XLNetForSequenceClassification`].
  754. """
  755. )
  756. class XLNetForSequenceClassificationOutput(ModelOutput):
  757. r"""
  758. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `label` is provided):
  759. Classification (or regression if config.num_labels==1) loss.
  760. logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
  761. Classification (or regression if config.num_labels==1) scores (before SoftMax).
  762. mems (`list[torch.FloatTensor]` of length `config.n_layers`):
  763. Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The
  764. token ids which have their past given to this model should not be passed as `input_ids` as they have
  765. already been computed.
  766. """
  767. loss: Optional[torch.FloatTensor] = None
  768. logits: Optional[torch.FloatTensor] = None
  769. mems: Optional[list[torch.FloatTensor]] = None
  770. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  771. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  772. @dataclass
  773. @auto_docstring(
  774. custom_intro="""
  775. Output type of [`XLNetForTokenClassificationOutput`].
  776. """
  777. )
  778. class XLNetForTokenClassificationOutput(ModelOutput):
  779. r"""
  780. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  781. Classification loss.
  782. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
  783. Classification scores (before SoftMax).
  784. mems (`list[torch.FloatTensor]` of length `config.n_layers`):
  785. Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The
  786. token ids which have their past given to this model should not be passed as `input_ids` as they have
  787. already been computed.
  788. """
  789. loss: Optional[torch.FloatTensor] = None
  790. logits: Optional[torch.FloatTensor] = None
  791. mems: Optional[list[torch.FloatTensor]] = None
  792. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  793. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  794. @dataclass
  795. @auto_docstring(
  796. custom_intro="""
  797. Output type of [`XLNetForMultipleChoice`].
  798. """
  799. )
  800. class XLNetForMultipleChoiceOutput(ModelOutput):
  801. r"""
  802. loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided):
  803. Classification loss.
  804. logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
  805. *num_choices* is the second dimension of the input tensors. (see *input_ids* above).
  806. Classification scores (before SoftMax).
  807. mems (`list[torch.FloatTensor]` of length `config.n_layers`):
  808. Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The
  809. token ids which have their past given to this model should not be passed as `input_ids` as they have
  810. already been computed.
  811. """
  812. loss: Optional[torch.FloatTensor] = None
  813. logits: Optional[torch.FloatTensor] = None
  814. mems: Optional[list[torch.FloatTensor]] = None
  815. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  816. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  817. @dataclass
  818. @auto_docstring(
  819. custom_intro="""
  820. Output type of [`XLNetForQuestionAnsweringSimple`].
  821. """
  822. )
  823. class XLNetForQuestionAnsweringSimpleOutput(ModelOutput):
  824. r"""
  825. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  826. Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
  827. start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length,)`):
  828. Span-start scores (before SoftMax).
  829. end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length,)`):
  830. Span-end scores (before SoftMax).
  831. mems (`list[torch.FloatTensor]` of length `config.n_layers`):
  832. Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The
  833. token ids which have their past given to this model should not be passed as `input_ids` as they have
  834. already been computed.
  835. """
  836. loss: Optional[torch.FloatTensor] = None
  837. start_logits: Optional[torch.FloatTensor] = None
  838. end_logits: Optional[torch.FloatTensor] = None
  839. mems: Optional[list[torch.FloatTensor]] = None
  840. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  841. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  842. @dataclass
  843. @auto_docstring(
  844. custom_intro="""
  845. Output type of [`XLNetForQuestionAnswering`].
  846. """
  847. )
  848. class XLNetForQuestionAnsweringOutput(ModelOutput):
  849. r"""
  850. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
  851. Classification loss as the sum of start token, end token (and is_impossible if provided) classification
  852. losses.
  853. start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  854. Log probabilities for the top config.start_n_top start token possibilities (beam-search).
  855. start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  856. Indices for the top config.start_n_top start token possibilities (beam-search).
  857. end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  858. Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
  859. (beam-search).
  860. end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  861. Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
  862. cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  863. Log probabilities for the `is_impossible` label of the answers.
  864. mems (`list[torch.FloatTensor]` of length `config.n_layers`):
  865. Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The
  866. token ids which have their past given to this model should not be passed as `input_ids` as they have
  867. already been computed.
  868. """
  869. loss: Optional[torch.FloatTensor] = None
  870. start_top_log_probs: Optional[torch.FloatTensor] = None
  871. start_top_index: Optional[torch.LongTensor] = None
  872. end_top_log_probs: Optional[torch.FloatTensor] = None
  873. end_top_index: Optional[torch.LongTensor] = None
  874. cls_logits: Optional[torch.FloatTensor] = None
  875. mems: Optional[list[torch.FloatTensor]] = None
  876. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  877. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  878. @auto_docstring
  879. class XLNetModel(XLNetPreTrainedModel):
  880. def __init__(self, config):
  881. super().__init__(config)
  882. self.mem_len = config.mem_len
  883. self.reuse_len = config.reuse_len
  884. self.d_model = config.d_model
  885. self.same_length = config.same_length
  886. self.attn_type = config.attn_type
  887. self.bi_data = config.bi_data
  888. self.clamp_len = config.clamp_len
  889. self.n_layer = config.n_layer
  890. self.word_embedding = nn.Embedding(config.vocab_size, config.d_model)
  891. self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, config.d_model))
  892. self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
  893. self.dropout = nn.Dropout(config.dropout)
  894. # Initialize weights and apply final processing
  895. self.post_init()
  896. def get_input_embeddings(self):
  897. return self.word_embedding
  898. def set_input_embeddings(self, new_embeddings):
  899. self.word_embedding = new_embeddings
  900. def _prune_heads(self, heads_to_prune):
  901. raise NotImplementedError
  902. def create_mask(self, qlen, mlen):
  903. """
  904. Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.
  905. Args:
  906. qlen: Sequence length
  907. mlen: Mask length
  908. ::
  909. same_length=False: same_length=True: <mlen > < qlen > <mlen > < qlen >
  910. ^ [0 0 0 0 0 1 1 1 1] [0 0 0 0 0 1 1 1 1]
  911. [0 0 0 0 0 0 1 1 1] [1 0 0 0 0 0 1 1 1]
  912. qlen [0 0 0 0 0 0 0 1 1] [1 1 0 0 0 0 0 1 1]
  913. [0 0 0 0 0 0 0 0 1] [1 1 1 0 0 0 0 0 1]
  914. v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
  915. """
  916. mask = torch.ones((qlen, qlen + mlen), device=self.device)
  917. if self.same_length:
  918. mask_lo = mask[:, :qlen].tril(-1)
  919. mask.triu_(mlen + 1)
  920. mask[:, :qlen] += mask_lo
  921. else:
  922. mask.triu_(mlen + 1)
  923. return mask
  924. def cache_mem(self, curr_out, prev_mem):
  925. # cache hidden states into memory.
  926. if self.reuse_len is not None and self.reuse_len > 0:
  927. curr_out = curr_out[: self.reuse_len]
  928. if self.mem_len is None or self.mem_len == 0:
  929. # If `use_mems` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
  930. # and returns all of the past and current hidden states.
  931. cutoff = 0
  932. else:
  933. # If `use_mems` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
  934. # states. This is the preferred setting for training and long-form generation.
  935. cutoff = -self.mem_len
  936. if prev_mem is None:
  937. # if `use_mems` is active and `mem_len` is defined, the model
  938. new_mem = curr_out[cutoff:]
  939. else:
  940. new_mem = torch.cat([prev_mem, curr_out], dim=0)[cutoff:]
  941. return new_mem.detach()
  942. @staticmethod
  943. def positional_embedding(pos_seq, inv_freq, bsz=None):
  944. sinusoid_inp = torch.einsum("i,d->id", pos_seq, inv_freq)
  945. pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
  946. pos_emb = pos_emb[:, None, :]
  947. if bsz is not None:
  948. pos_emb = pos_emb.expand(-1, bsz, -1)
  949. return pos_emb
  950. def relative_positional_encoding(self, qlen, klen, bsz=None):
  951. # create relative positional encoding.
  952. freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.int64).float()
  953. inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))
  954. if self.attn_type == "bi":
  955. # beg, end = klen - 1, -qlen
  956. beg, end = klen, -qlen
  957. elif self.attn_type == "uni":
  958. # beg, end = klen - 1, -1
  959. beg, end = klen, -1
  960. else:
  961. raise ValueError(f"Unknown `attn_type` {self.attn_type}.")
  962. if self.bi_data:
  963. fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64).float()
  964. bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.int64).float()
  965. if self.clamp_len > 0:
  966. fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
  967. bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
  968. if bsz is not None:
  969. fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
  970. bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
  971. else:
  972. fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
  973. bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)
  974. pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
  975. else:
  976. fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64).float()
  977. if self.clamp_len > 0:
  978. fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
  979. pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
  980. return pos_emb
  981. @auto_docstring
  982. def forward(
  983. self,
  984. input_ids: Optional[torch.Tensor] = None,
  985. attention_mask: Optional[torch.Tensor] = None,
  986. mems: Optional[torch.Tensor] = None,
  987. perm_mask: Optional[torch.Tensor] = None,
  988. target_mapping: Optional[torch.Tensor] = None,
  989. token_type_ids: Optional[torch.Tensor] = None,
  990. input_mask: Optional[torch.Tensor] = None,
  991. head_mask: Optional[torch.Tensor] = None,
  992. inputs_embeds: Optional[torch.Tensor] = None,
  993. use_mems: Optional[bool] = None,
  994. output_attentions: Optional[bool] = None,
  995. output_hidden_states: Optional[bool] = None,
  996. return_dict: Optional[bool] = None,
  997. **kwargs, # delete after depreciation warning is removed
  998. ) -> Union[tuple, XLNetModelOutput]:
  999. r"""
  1000. mems (`list[torch.FloatTensor]` of length `config.n_layers`):
  1001. Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential
  1002. decoding. The token ids which have their past given to this model should not be passed as `input_ids` as
  1003. they have already been computed.
  1004. `use_mems` has to be set to `True` to make use of `mems`.
  1005. perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*):
  1006. Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`:
  1007. - if `perm_mask[k, i, j] = 0`, i attend to j in batch k;
  1008. - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k.
  1009. If not set, each token attends to all the others (full bidirectional attention). Only used during
  1010. pretraining (to define factorization order) or for sequential decoding (generation).
  1011. target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*):
  1012. Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is
  1013. on the j-th token. Only used during pretraining for partial prediction or for sequential decoding
  1014. (generation).
  1015. input_mask (`torch.FloatTensor` of shape `batch_size, sequence_length`, *optional*):
  1016. Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for
  1017. real tokens and 1 for padding which is kept for compatibility with the original code base.
  1018. Mask values selected in `[0, 1]`:
  1019. - 1 for tokens that are **masked**,
  1020. - 0 for tokens that are **not masked**.
  1021. You can only uses one of `input_mask` and `attention_mask`.
  1022. use_mems (`bool`, *optional*):
  1023. Whether to use memory states to speed up sequential decoding. If set to `True`, the model will use the hidden
  1024. states from previous forward passes to compute attention, which can significantly improve performance for
  1025. sequential decoding tasks.
  1026. """
  1027. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1028. output_hidden_states = (
  1029. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1030. )
  1031. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1032. if "use_cache" in kwargs:
  1033. warnings.warn(
  1034. "The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems`"
  1035. " instead.",
  1036. FutureWarning,
  1037. )
  1038. use_mems = kwargs["use_cache"]
  1039. if self.training:
  1040. use_mems = use_mems if use_mems is not None else self.config.use_mems_train
  1041. else:
  1042. use_mems = use_mems if use_mems is not None else self.config.use_mems_eval
  1043. # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
  1044. # but we want a unified interface in the library with the batch size on the first dimension
  1045. # so we move here the first dimension (batch) to the end
  1046. if input_ids is not None and inputs_embeds is not None:
  1047. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  1048. elif input_ids is not None:
  1049. input_ids = input_ids.transpose(0, 1).contiguous()
  1050. qlen, bsz = input_ids.shape[0], input_ids.shape[1]
  1051. elif inputs_embeds is not None:
  1052. inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
  1053. qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
  1054. else:
  1055. raise ValueError("You have to specify either input_ids or inputs_embeds")
  1056. token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
  1057. input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
  1058. attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
  1059. perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
  1060. target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
  1061. mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
  1062. klen = mlen + qlen
  1063. dtype_float = self.dtype
  1064. device = self.device
  1065. # Attention mask
  1066. # causal attention mask
  1067. if self.attn_type == "uni":
  1068. attn_mask = self.create_mask(qlen, mlen)
  1069. attn_mask = attn_mask[:, :, None, None]
  1070. elif self.attn_type == "bi":
  1071. attn_mask = None
  1072. else:
  1073. raise ValueError(f"Unsupported attention type: {self.attn_type}")
  1074. # data mask: input mask & perm mask
  1075. assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) "
  1076. "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
  1077. if input_mask is None and attention_mask is not None:
  1078. input_mask = 1.0 - attention_mask
  1079. if input_mask is not None and perm_mask is not None:
  1080. data_mask = input_mask[None] + perm_mask
  1081. elif input_mask is not None and perm_mask is None:
  1082. data_mask = input_mask[None]
  1083. elif input_mask is None and perm_mask is not None:
  1084. data_mask = perm_mask
  1085. else:
  1086. data_mask = None
  1087. if data_mask is not None:
  1088. # all mems can be attended to
  1089. if mlen > 0:
  1090. mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
  1091. data_mask = torch.cat([mems_mask, data_mask], dim=1)
  1092. if attn_mask is None:
  1093. attn_mask = data_mask[:, :, :, None]
  1094. else:
  1095. attn_mask += data_mask[:, :, :, None]
  1096. if attn_mask is not None:
  1097. attn_mask = (attn_mask > 0).to(dtype_float)
  1098. if attn_mask is not None:
  1099. non_tgt_mask = -torch.eye(qlen).to(attn_mask)
  1100. if mlen > 0:
  1101. non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
  1102. non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
  1103. else:
  1104. non_tgt_mask = None
  1105. # Word embeddings and prepare h & g hidden states
  1106. if inputs_embeds is not None:
  1107. word_emb_k = inputs_embeds
  1108. else:
  1109. word_emb_k = self.word_embedding(input_ids)
  1110. output_h = self.dropout(word_emb_k)
  1111. if target_mapping is not None:
  1112. word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
  1113. # else: # We removed the inp_q input which was same as target mapping
  1114. # inp_q_ext = inp_q[:, :, None]
  1115. # word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
  1116. output_g = self.dropout(word_emb_q)
  1117. else:
  1118. output_g = None
  1119. # Segment embedding
  1120. if token_type_ids is not None:
  1121. # Convert `token_type_ids` to one-hot `seg_mat`
  1122. if mlen > 0:
  1123. mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
  1124. cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
  1125. else:
  1126. cat_ids = token_type_ids
  1127. # `1` indicates not in the same segment [qlen x klen x bsz]
  1128. seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
  1129. seg_mat = nn.functional.one_hot(seg_mat, num_classes=2).to(dtype_float)
  1130. else:
  1131. seg_mat = None
  1132. # Positional encoding
  1133. pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
  1134. pos_emb = pos_emb.to(output_h.device)
  1135. pos_emb = self.dropout(pos_emb)
  1136. # Prepare head mask if needed
  1137. # 1.0 in head_mask indicate we keep the head
  1138. # attention_probs has shape bsz x n_heads x N x N
  1139. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
  1140. # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
  1141. if head_mask is not None:
  1142. if head_mask.dim() == 1:
  1143. head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
  1144. head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
  1145. elif head_mask.dim() == 2:
  1146. head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
  1147. head_mask = head_mask.to(
  1148. dtype=next(self.parameters()).dtype
  1149. ) # switch to float if need + fp16 compatibility
  1150. else:
  1151. head_mask = [None] * self.n_layer
  1152. new_mems = ()
  1153. if mems is None:
  1154. mems = [None] * len(self.layer)
  1155. attentions = [] if output_attentions else None
  1156. hidden_states = [] if output_hidden_states else None
  1157. for i, layer_module in enumerate(self.layer):
  1158. if use_mems:
  1159. # cache new mems
  1160. new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
  1161. if output_hidden_states:
  1162. hidden_states.append((output_h, output_g) if output_g is not None else output_h)
  1163. outputs = layer_module(
  1164. output_h,
  1165. output_g,
  1166. attn_mask_h=non_tgt_mask,
  1167. attn_mask_g=attn_mask,
  1168. r=pos_emb,
  1169. seg_mat=seg_mat,
  1170. mems=mems[i],
  1171. target_mapping=target_mapping,
  1172. head_mask=head_mask[i],
  1173. output_attentions=output_attentions,
  1174. )
  1175. output_h, output_g = outputs[:2]
  1176. if output_attentions:
  1177. attentions.append(outputs[2])
  1178. # Add last hidden state
  1179. if output_hidden_states:
  1180. hidden_states.append((output_h, output_g) if output_g is not None else output_h)
  1181. output = self.dropout(output_g if output_g is not None else output_h)
  1182. # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
  1183. output = output.permute(1, 0, 2).contiguous()
  1184. if not use_mems:
  1185. new_mems = None
  1186. if output_hidden_states:
  1187. if output_g is not None:
  1188. hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
  1189. else:
  1190. hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
  1191. if output_attentions:
  1192. if target_mapping is not None:
  1193. # when target_mapping is provided, there are 2-tuple of attentions
  1194. attentions = tuple(
  1195. tuple(att_stream.permute(2, 3, 0, 1).contiguous() for att_stream in t) for t in attentions
  1196. )
  1197. else:
  1198. attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
  1199. if not return_dict:
  1200. return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)
  1201. return XLNetModelOutput(
  1202. last_hidden_state=output, mems=new_mems, hidden_states=hidden_states, attentions=attentions
  1203. )
  1204. @auto_docstring(
  1205. custom_intro="""
  1206. XLNet Model with a language modeling head on top (linear layer with weights tied to the input embeddings).
  1207. """
  1208. )
  1209. class XLNetLMHeadModel(XLNetPreTrainedModel, GenerationMixin):
  1210. _tied_weights_keys = ["lm_loss.weight"]
  1211. def __init__(self, config):
  1212. super().__init__(config)
  1213. self.attn_type = config.attn_type
  1214. self.same_length = config.same_length
  1215. self.transformer = XLNetModel(config)
  1216. self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)
  1217. # Initialize weights and apply final processing
  1218. self.post_init()
  1219. def get_output_embeddings(self):
  1220. return self.lm_loss
  1221. def set_output_embeddings(self, new_embeddings):
  1222. self.lm_loss = new_embeddings
  1223. def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_mems=None, **kwargs):
  1224. # Overwritten -- this model has unique input preparation
  1225. # Add dummy token at the end (no attention on this one)
  1226. effective_batch_size = input_ids.shape[0]
  1227. dummy_token = torch.zeros((effective_batch_size, 1), dtype=torch.long, device=input_ids.device)
  1228. # At every pass, the attention values for the new token and the two last generated tokens
  1229. # are computed, the rest is reloaded from the `past` cache. A purely auto-regressive model would have
  1230. # offset = 1; offset = 2 seems to have slightly better computation.
  1231. offset = 2
  1232. if past_key_values:
  1233. input_ids = torch.cat([input_ids[:, -offset:], dummy_token], dim=1)
  1234. else:
  1235. input_ids = torch.cat([input_ids, dummy_token], dim=1)
  1236. # Build permutation mask so that previous tokens don't see last token
  1237. sequence_length = input_ids.shape[1]
  1238. perm_mask = torch.zeros(
  1239. (effective_batch_size, sequence_length, sequence_length), dtype=torch.float, device=input_ids.device
  1240. )
  1241. perm_mask[:, :, -1] = 1.0
  1242. # We'll only predict the last token
  1243. target_mapping = torch.zeros(
  1244. (effective_batch_size, 1, sequence_length), dtype=torch.float, device=input_ids.device
  1245. )
  1246. target_mapping[:, 0, -1] = 1.0
  1247. model_inputs = {
  1248. "input_ids": input_ids,
  1249. "perm_mask": perm_mask,
  1250. "target_mapping": target_mapping,
  1251. "use_mems": use_mems,
  1252. }
  1253. # if past is defined in model kwargs then use it for faster decoding
  1254. if past_key_values:
  1255. model_inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past_key_values)
  1256. # Attention mask is computed on the fly on XLNetModel.forward()
  1257. kwargs.pop("attention_mask", None)
  1258. # TODO: Ignoring use_cache should not happen, fixme.
  1259. kwargs.pop("use_cache", None)
  1260. # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
  1261. for key, value in kwargs.items():
  1262. if key not in model_inputs:
  1263. model_inputs[key] = value
  1264. return model_inputs
  1265. @auto_docstring
  1266. def forward(
  1267. self,
  1268. input_ids: Optional[torch.Tensor] = None,
  1269. attention_mask: Optional[torch.Tensor] = None,
  1270. mems: Optional[torch.Tensor] = None,
  1271. perm_mask: Optional[torch.Tensor] = None,
  1272. target_mapping: Optional[torch.Tensor] = None,
  1273. token_type_ids: Optional[torch.Tensor] = None,
  1274. input_mask: Optional[torch.Tensor] = None,
  1275. head_mask: Optional[torch.Tensor] = None,
  1276. inputs_embeds: Optional[torch.Tensor] = None,
  1277. labels: Optional[torch.Tensor] = None,
  1278. use_mems: Optional[bool] = None,
  1279. output_attentions: Optional[bool] = None,
  1280. output_hidden_states: Optional[bool] = None,
  1281. return_dict: Optional[bool] = None,
  1282. **kwargs, # delete when `use_cache` is removed in XLNetModel
  1283. ) -> Union[tuple, XLNetLMHeadModelOutput]:
  1284. r"""
  1285. mems (`list[torch.FloatTensor]` of length `config.n_layers`):
  1286. Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential
  1287. decoding. The token ids which have their past given to this model should not be passed as `input_ids` as
  1288. they have already been computed.
  1289. `use_mems` has to be set to `True` to make use of `mems`.
  1290. perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*):
  1291. Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`:
  1292. - if `perm_mask[k, i, j] = 0`, i attend to j in batch k;
  1293. - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k.
  1294. If not set, each token attends to all the others (full bidirectional attention). Only used during
  1295. pretraining (to define factorization order) or for sequential decoding (generation).
  1296. target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*):
  1297. Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is
  1298. on the j-th token. Only used during pretraining for partial prediction or for sequential decoding
  1299. (generation).
  1300. input_mask (`torch.FloatTensor` of shape `batch_size, sequence_length`, *optional*):
  1301. Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for
  1302. real tokens and 1 for padding which is kept for compatibility with the original code base.
  1303. Mask values selected in `[0, 1]`:
  1304. - 1 for tokens that are **masked**,
  1305. - 0 for tokens that are **not masked**.
  1306. You can only uses one of `input_mask` and `attention_mask`.
  1307. labels (`torch.LongTensor` of shape `(batch_size, num_predict)`, *optional*):
  1308. Labels for masked language modeling. `num_predict` corresponds to `target_mapping.shape[1]`. If
  1309. `target_mapping` is `None`, then `num_predict` corresponds to `sequence_length`.
  1310. The labels should correspond to the masked input words that should be predicted and depends on
  1311. `target_mapping`. Note in order to perform standard auto-regressive language modeling a *<mask>* token has
  1312. to be added to the `input_ids` (see the `prepare_inputs_for_generation` function and examples below)
  1313. Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored, the loss
  1314. is only computed for labels in `[0, ..., config.vocab_size]`
  1315. use_mems (`bool`, *optional*):
  1316. Whether to use memory states to speed up sequential decoding. If set to `True`, the model will use the hidden
  1317. states from previous forward passes to compute attention, which can significantly improve performance for
  1318. sequential decoding tasks.
  1319. Examples:
  1320. ```python
  1321. >>> from transformers import AutoTokenizer, XLNetLMHeadModel
  1322. >>> import torch
  1323. >>> tokenizer = AutoTokenizer.from_pretrained("xlnet/xlnet-large-cased")
  1324. >>> model = XLNetLMHeadModel.from_pretrained("xlnet/xlnet-large-cased")
  1325. >>> # We show how to setup inputs to predict a next token using a bi-directional context.
  1326. >>> input_ids = torch.tensor(
  1327. ... tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=False)
  1328. ... ).unsqueeze(
  1329. ... 0
  1330. ... ) # We will predict the masked token
  1331. >>> perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
  1332. >>> perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
  1333. >>> target_mapping = torch.zeros(
  1334. ... (1, 1, input_ids.shape[1]), dtype=torch.float
  1335. ... ) # Shape [1, 1, seq_length] => let's predict one token
  1336. >>> target_mapping[
  1337. ... 0, 0, -1
  1338. ... ] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token)
  1339. >>> outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
  1340. >>> next_token_logits = outputs[
  1341. ... 0
  1342. ... ] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
  1343. >>> # The same way can the XLNetLMHeadModel be used to be trained by standard auto-regressive language modeling.
  1344. >>> input_ids = torch.tensor(
  1345. ... tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=False)
  1346. ... ).unsqueeze(
  1347. ... 0
  1348. ... ) # We will predict the masked token
  1349. >>> labels = torch.tensor(tokenizer.encode("cute", add_special_tokens=False)).unsqueeze(0)
  1350. >>> assert labels.shape[0] == 1, "only one word will be predicted"
  1351. >>> perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
  1352. >>> perm_mask[
  1353. ... :, :, -1
  1354. ... ] = 1.0 # Previous tokens don't see last token as is done in standard auto-regressive lm training
  1355. >>> target_mapping = torch.zeros(
  1356. ... (1, 1, input_ids.shape[1]), dtype=torch.float
  1357. ... ) # Shape [1, 1, seq_length] => let's predict one token
  1358. >>> target_mapping[
  1359. ... 0, 0, -1
  1360. ... ] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token)
  1361. >>> outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping, labels=labels)
  1362. >>> loss = outputs.loss
  1363. >>> next_token_logits = (
  1364. ... outputs.logits
  1365. ... ) # Logits have shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
  1366. ```"""
  1367. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1368. transformer_outputs = self.transformer(
  1369. input_ids,
  1370. attention_mask=attention_mask,
  1371. mems=mems,
  1372. perm_mask=perm_mask,
  1373. target_mapping=target_mapping,
  1374. token_type_ids=token_type_ids,
  1375. input_mask=input_mask,
  1376. head_mask=head_mask,
  1377. inputs_embeds=inputs_embeds,
  1378. use_mems=use_mems,
  1379. output_attentions=output_attentions,
  1380. output_hidden_states=output_hidden_states,
  1381. return_dict=return_dict,
  1382. **kwargs,
  1383. )
  1384. logits = self.lm_loss(transformer_outputs[0])
  1385. loss = None
  1386. if labels is not None:
  1387. # Flatten the tokens
  1388. loss_fct = CrossEntropyLoss()
  1389. loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
  1390. if not return_dict:
  1391. output = (logits,) + transformer_outputs[1:]
  1392. return ((loss,) + output) if loss is not None else output
  1393. return XLNetLMHeadModelOutput(
  1394. loss=loss,
  1395. logits=logits,
  1396. mems=transformer_outputs.mems,
  1397. hidden_states=transformer_outputs.hidden_states,
  1398. attentions=transformer_outputs.attentions,
  1399. )
  1400. @staticmethod
  1401. def _reorder_cache(mems: list[torch.Tensor], beam_idx: torch.Tensor) -> list[torch.Tensor]:
  1402. """
  1403. This function is used to re-order the `mems` cache if [`~PreTrainedModel.beam_search`] or
  1404. [`~PreTrainedModel.beam_sample`] is called. This is required to match `mems` with the correct beam_idx at every
  1405. generation step.
  1406. """
  1407. return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]
  1408. @auto_docstring(
  1409. custom_intro="""
  1410. XLNet Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g.
  1411. for GLUE tasks.
  1412. """
  1413. )
  1414. class XLNetForSequenceClassification(XLNetPreTrainedModel):
  1415. def __init__(self, config):
  1416. super().__init__(config)
  1417. self.num_labels = config.num_labels
  1418. self.config = config
  1419. self.transformer = XLNetModel(config)
  1420. self.sequence_summary = XLNetSequenceSummary(config)
  1421. self.logits_proj = nn.Linear(config.d_model, config.num_labels)
  1422. # Initialize weights and apply final processing
  1423. self.post_init()
  1424. @auto_docstring
  1425. def forward(
  1426. self,
  1427. input_ids: Optional[torch.Tensor] = None,
  1428. attention_mask: Optional[torch.Tensor] = None,
  1429. mems: Optional[torch.Tensor] = None,
  1430. perm_mask: Optional[torch.Tensor] = None,
  1431. target_mapping: Optional[torch.Tensor] = None,
  1432. token_type_ids: Optional[torch.Tensor] = None,
  1433. input_mask: Optional[torch.Tensor] = None,
  1434. head_mask: Optional[torch.Tensor] = None,
  1435. inputs_embeds: Optional[torch.Tensor] = None,
  1436. labels: Optional[torch.Tensor] = None,
  1437. use_mems: Optional[bool] = None,
  1438. output_attentions: Optional[bool] = None,
  1439. output_hidden_states: Optional[bool] = None,
  1440. return_dict: Optional[bool] = None,
  1441. **kwargs, # delete when `use_cache` is removed in XLNetModel
  1442. ) -> Union[tuple, XLNetForSequenceClassificationOutput]:
  1443. r"""
  1444. mems (`list[torch.FloatTensor]` of length `config.n_layers`):
  1445. Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential
  1446. decoding. The token ids which have their past given to this model should not be passed as `input_ids` as
  1447. they have already been computed.
  1448. `use_mems` has to be set to `True` to make use of `mems`.
  1449. perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*):
  1450. Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`:
  1451. - if `perm_mask[k, i, j] = 0`, i attend to j in batch k;
  1452. - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k.
  1453. If not set, each token attends to all the others (full bidirectional attention). Only used during
  1454. pretraining (to define factorization order) or for sequential decoding (generation).
  1455. target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*):
  1456. Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is
  1457. on the j-th token. Only used during pretraining for partial prediction or for sequential decoding
  1458. (generation).
  1459. input_mask (`torch.FloatTensor` of shape `batch_size, sequence_length`, *optional*):
  1460. Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for
  1461. real tokens and 1 for padding which is kept for compatibility with the original code base.
  1462. Mask values selected in `[0, 1]`:
  1463. - 1 for tokens that are **masked**,
  1464. - 0 for tokens that are **not masked**.
  1465. You can only uses one of `input_mask` and `attention_mask`.
  1466. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1467. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1468. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1469. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1470. use_mems (`bool`, *optional*):
  1471. Whether to use memory states to speed up sequential decoding. If set to `True`, the model will use the hidden
  1472. states from previous forward passes to compute attention, which can significantly improve performance for
  1473. sequential decoding tasks.
  1474. """
  1475. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1476. transformer_outputs = self.transformer(
  1477. input_ids,
  1478. attention_mask=attention_mask,
  1479. mems=mems,
  1480. perm_mask=perm_mask,
  1481. target_mapping=target_mapping,
  1482. token_type_ids=token_type_ids,
  1483. input_mask=input_mask,
  1484. head_mask=head_mask,
  1485. inputs_embeds=inputs_embeds,
  1486. use_mems=use_mems,
  1487. output_attentions=output_attentions,
  1488. output_hidden_states=output_hidden_states,
  1489. return_dict=return_dict,
  1490. **kwargs,
  1491. )
  1492. output = transformer_outputs[0]
  1493. output = self.sequence_summary(output)
  1494. logits = self.logits_proj(output)
  1495. loss = None
  1496. if labels is not None:
  1497. if self.config.problem_type is None:
  1498. if self.num_labels == 1:
  1499. self.config.problem_type = "regression"
  1500. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1501. self.config.problem_type = "single_label_classification"
  1502. else:
  1503. self.config.problem_type = "multi_label_classification"
  1504. if self.config.problem_type == "regression":
  1505. loss_fct = MSELoss()
  1506. if self.num_labels == 1:
  1507. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1508. else:
  1509. loss = loss_fct(logits, labels)
  1510. elif self.config.problem_type == "single_label_classification":
  1511. loss_fct = CrossEntropyLoss()
  1512. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1513. elif self.config.problem_type == "multi_label_classification":
  1514. loss_fct = BCEWithLogitsLoss()
  1515. loss = loss_fct(logits, labels)
  1516. if not return_dict:
  1517. output = (logits,) + transformer_outputs[1:]
  1518. return ((loss,) + output) if loss is not None else output
  1519. return XLNetForSequenceClassificationOutput(
  1520. loss=loss,
  1521. logits=logits,
  1522. mems=transformer_outputs.mems,
  1523. hidden_states=transformer_outputs.hidden_states,
  1524. attentions=transformer_outputs.attentions,
  1525. )
  1526. @auto_docstring
  1527. class XLNetForTokenClassification(XLNetPreTrainedModel):
  1528. def __init__(self, config):
  1529. super().__init__(config)
  1530. self.num_labels = config.num_labels
  1531. self.transformer = XLNetModel(config)
  1532. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1533. # Initialize weights and apply final processing
  1534. self.post_init()
  1535. @auto_docstring
  1536. def forward(
  1537. self,
  1538. input_ids: Optional[torch.Tensor] = None,
  1539. attention_mask: Optional[torch.Tensor] = None,
  1540. mems: Optional[torch.Tensor] = None,
  1541. perm_mask: Optional[torch.Tensor] = None,
  1542. target_mapping: Optional[torch.Tensor] = None,
  1543. token_type_ids: Optional[torch.Tensor] = None,
  1544. input_mask: Optional[torch.Tensor] = None,
  1545. head_mask: Optional[torch.Tensor] = None,
  1546. inputs_embeds: Optional[torch.Tensor] = None,
  1547. labels: Optional[torch.Tensor] = None,
  1548. use_mems: Optional[bool] = None,
  1549. output_attentions: Optional[bool] = None,
  1550. output_hidden_states: Optional[bool] = None,
  1551. return_dict: Optional[bool] = None,
  1552. **kwargs, # delete when `use_cache` is removed in XLNetModel
  1553. ) -> Union[tuple, XLNetForTokenClassificationOutput]:
  1554. r"""
  1555. mems (`list[torch.FloatTensor]` of length `config.n_layers`):
  1556. Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential
  1557. decoding. The token ids which have their past given to this model should not be passed as `input_ids` as
  1558. they have already been computed.
  1559. `use_mems` has to be set to `True` to make use of `mems`.
  1560. perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*):
  1561. Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`:
  1562. - if `perm_mask[k, i, j] = 0`, i attend to j in batch k;
  1563. - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k.
  1564. If not set, each token attends to all the others (full bidirectional attention). Only used during
  1565. pretraining (to define factorization order) or for sequential decoding (generation).
  1566. target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*):
  1567. Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is
  1568. on the j-th token. Only used during pretraining for partial prediction or for sequential decoding
  1569. (generation).
  1570. input_mask (`torch.FloatTensor` of shape `batch_size, sequence_length`, *optional*):
  1571. Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for
  1572. real tokens and 1 for padding which is kept for compatibility with the original code base.
  1573. Mask values selected in `[0, 1]`:
  1574. - 1 for tokens that are **masked**,
  1575. - 0 for tokens that are **not masked**.
  1576. You can only uses one of `input_mask` and `attention_mask`.
  1577. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1578. Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
  1579. where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
  1580. use_mems (`bool`, *optional*):
  1581. Whether to use memory states to speed up sequential decoding. If set to `True`, the model will use the hidden
  1582. states from previous forward passes to compute attention, which can significantly improve performance for
  1583. sequential decoding tasks.emory states to speed up sequential decoding. If set to `True`, the model will use the hidden
  1584. states from previous forward passes to compute attention, which can significantly improve performance for
  1585. sequential decoding tasks.
  1586. """
  1587. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1588. outputs = self.transformer(
  1589. input_ids,
  1590. attention_mask=attention_mask,
  1591. mems=mems,
  1592. perm_mask=perm_mask,
  1593. target_mapping=target_mapping,
  1594. token_type_ids=token_type_ids,
  1595. input_mask=input_mask,
  1596. head_mask=head_mask,
  1597. inputs_embeds=inputs_embeds,
  1598. use_mems=use_mems,
  1599. output_attentions=output_attentions,
  1600. output_hidden_states=output_hidden_states,
  1601. return_dict=return_dict,
  1602. )
  1603. sequence_output = outputs[0]
  1604. logits = self.classifier(sequence_output)
  1605. loss = None
  1606. if labels is not None:
  1607. loss_fct = CrossEntropyLoss()
  1608. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1609. if not return_dict:
  1610. output = (logits,) + outputs[1:]
  1611. return ((loss,) + output) if loss is not None else output
  1612. return XLNetForTokenClassificationOutput(
  1613. loss=loss,
  1614. logits=logits,
  1615. mems=outputs.mems,
  1616. hidden_states=outputs.hidden_states,
  1617. attentions=outputs.attentions,
  1618. )
  1619. @auto_docstring
  1620. class XLNetForMultipleChoice(XLNetPreTrainedModel):
  1621. def __init__(self, config):
  1622. super().__init__(config)
  1623. self.transformer = XLNetModel(config)
  1624. self.sequence_summary = XLNetSequenceSummary(config)
  1625. self.logits_proj = nn.Linear(config.d_model, 1)
  1626. # Initialize weights and apply final processing
  1627. self.post_init()
  1628. @auto_docstring
  1629. def forward(
  1630. self,
  1631. input_ids: Optional[torch.Tensor] = None,
  1632. token_type_ids: Optional[torch.Tensor] = None,
  1633. input_mask: Optional[torch.Tensor] = None,
  1634. attention_mask: Optional[torch.Tensor] = None,
  1635. mems: Optional[torch.Tensor] = None,
  1636. perm_mask: Optional[torch.Tensor] = None,
  1637. target_mapping: Optional[torch.Tensor] = None,
  1638. head_mask: Optional[torch.Tensor] = None,
  1639. inputs_embeds: Optional[torch.Tensor] = None,
  1640. labels: Optional[torch.Tensor] = None,
  1641. use_mems: Optional[bool] = None,
  1642. output_attentions: Optional[bool] = None,
  1643. output_hidden_states: Optional[bool] = None,
  1644. return_dict: Optional[bool] = None,
  1645. **kwargs, # delete when `use_cache` is removed in XLNetModel
  1646. ) -> Union[tuple, XLNetForMultipleChoiceOutput]:
  1647. r"""
  1648. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  1649. Indices of input sequence tokens in the vocabulary.
  1650. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1651. [`PreTrainedTokenizer.__call__`] for details.
  1652. [What are input IDs?](../glossary#input-ids)
  1653. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1654. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1655. 1]`:
  1656. - 0 corresponds to a *sentence A* token,
  1657. - 1 corresponds to a *sentence B* token.
  1658. [What are token type IDs?](../glossary#token-type-ids)
  1659. input_mask (`torch.FloatTensor` of shape `batch_size, num_choices, sequence_length`, *optional*):
  1660. Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for
  1661. real tokens and 1 for padding which is kept for compatibility with the original code base.
  1662. Mask values selected in `[0, 1]`:
  1663. - 1 for tokens that are **masked**,
  1664. - 0 for tokens that are **not masked**.
  1665. You can only uses one of `input_mask` and `attention_mask`.
  1666. mems (`list[torch.FloatTensor]` of length `config.n_layers`):
  1667. Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential
  1668. decoding. The token ids which have their past given to this model should not be passed as `input_ids` as
  1669. they have already been computed.
  1670. `use_mems` has to be set to `True` to make use of `mems`.
  1671. perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*):
  1672. Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`:
  1673. - if `perm_mask[k, i, j] = 0`, i attend to j in batch k;
  1674. - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k.
  1675. If not set, each token attends to all the others (full bidirectional attention). Only used during
  1676. pretraining (to define factorization order) or for sequential decoding (generation).
  1677. target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*):
  1678. Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is
  1679. on the j-th token. Only used during pretraining for partial prediction or for sequential decoding
  1680. (generation).
  1681. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  1682. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1683. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1684. model's internal embedding lookup matrix.
  1685. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1686. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1687. use_mems (`bool`, *optional*):
  1688. Whether to use memory states to speed up sequential decoding. If set to `True`, the model will use the hidden
  1689. states from previous forward passes to compute attention, which can significantly improve performance for
  1690. sequential decoding tasks.
  1691. """
  1692. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1693. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1694. flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1695. flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1696. flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1697. flat_input_mask = input_mask.view(-1, input_mask.size(-1)) if input_mask is not None else None
  1698. flat_inputs_embeds = (
  1699. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1700. if inputs_embeds is not None
  1701. else None
  1702. )
  1703. transformer_outputs = self.transformer(
  1704. flat_input_ids,
  1705. token_type_ids=flat_token_type_ids,
  1706. input_mask=flat_input_mask,
  1707. attention_mask=flat_attention_mask,
  1708. mems=mems,
  1709. perm_mask=perm_mask,
  1710. target_mapping=target_mapping,
  1711. head_mask=head_mask,
  1712. inputs_embeds=flat_inputs_embeds,
  1713. use_mems=use_mems,
  1714. output_attentions=output_attentions,
  1715. output_hidden_states=output_hidden_states,
  1716. return_dict=return_dict,
  1717. **kwargs,
  1718. )
  1719. output = transformer_outputs[0]
  1720. output = self.sequence_summary(output)
  1721. logits = self.logits_proj(output)
  1722. reshaped_logits = logits.view(-1, num_choices)
  1723. loss = None
  1724. if labels is not None:
  1725. loss_fct = CrossEntropyLoss()
  1726. loss = loss_fct(reshaped_logits, labels.view(-1))
  1727. if not return_dict:
  1728. output = (reshaped_logits,) + transformer_outputs[1:]
  1729. return ((loss,) + output) if loss is not None else output
  1730. return XLNetForMultipleChoiceOutput(
  1731. loss=loss,
  1732. logits=reshaped_logits,
  1733. mems=transformer_outputs.mems,
  1734. hidden_states=transformer_outputs.hidden_states,
  1735. attentions=transformer_outputs.attentions,
  1736. )
  1737. @auto_docstring(
  1738. custom_intro="""
  1739. XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
  1740. layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
  1741. """
  1742. )
  1743. class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
  1744. def __init__(self, config):
  1745. super().__init__(config)
  1746. self.num_labels = config.num_labels
  1747. self.transformer = XLNetModel(config)
  1748. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1749. # Initialize weights and apply final processing
  1750. self.post_init()
  1751. @auto_docstring
  1752. def forward(
  1753. self,
  1754. input_ids: Optional[torch.Tensor] = None,
  1755. attention_mask: Optional[torch.Tensor] = None,
  1756. mems: Optional[torch.Tensor] = None,
  1757. perm_mask: Optional[torch.Tensor] = None,
  1758. target_mapping: Optional[torch.Tensor] = None,
  1759. token_type_ids: Optional[torch.Tensor] = None,
  1760. input_mask: Optional[torch.Tensor] = None,
  1761. head_mask: Optional[torch.Tensor] = None,
  1762. inputs_embeds: Optional[torch.Tensor] = None,
  1763. start_positions: Optional[torch.Tensor] = None,
  1764. end_positions: Optional[torch.Tensor] = None,
  1765. use_mems: Optional[bool] = None,
  1766. output_attentions: Optional[bool] = None,
  1767. output_hidden_states: Optional[bool] = None,
  1768. return_dict: Optional[bool] = None,
  1769. **kwargs, # delete when `use_cache` is removed in XLNetModel
  1770. ) -> Union[tuple, XLNetForQuestionAnsweringSimpleOutput]:
  1771. r"""
  1772. mems (`list[torch.FloatTensor]` of length `config.n_layers`):
  1773. Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential
  1774. decoding. The token ids which have their past given to this model should not be passed as `input_ids` as
  1775. they have already been computed.
  1776. `use_mems` has to be set to `True` to make use of `mems`.
  1777. perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*):
  1778. Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`:
  1779. - if `perm_mask[k, i, j] = 0`, i attend to j in batch k;
  1780. - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k.
  1781. If not set, each token attends to all the others (full bidirectional attention). Only used during
  1782. pretraining (to define factorization order) or for sequential decoding (generation).
  1783. target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*):
  1784. Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is
  1785. on the j-th token. Only used during pretraining for partial prediction or for sequential decoding
  1786. (generation).
  1787. input_mask (`torch.FloatTensor` of shape `batch_size, sequence_length`, *optional*):
  1788. Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for
  1789. real tokens and 1 for padding which is kept for compatibility with the original code base.
  1790. Mask values selected in `[0, 1]`:
  1791. - 1 for tokens that are **masked**,
  1792. - 0 for tokens that are **not masked**.
  1793. You can only uses one of `input_mask` and `attention_mask`.
  1794. use_mems (`bool`, *optional*):
  1795. Whether to use memory states to speed up sequential decoding. If set to `True`, the model will use the hidden
  1796. states from previous forward passes to compute attention, which can significantly improve performance for
  1797. sequential decoding tasks.
  1798. """
  1799. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1800. outputs = self.transformer(
  1801. input_ids,
  1802. attention_mask=attention_mask,
  1803. mems=mems,
  1804. perm_mask=perm_mask,
  1805. target_mapping=target_mapping,
  1806. token_type_ids=token_type_ids,
  1807. input_mask=input_mask,
  1808. head_mask=head_mask,
  1809. inputs_embeds=inputs_embeds,
  1810. use_mems=use_mems,
  1811. output_attentions=output_attentions,
  1812. output_hidden_states=output_hidden_states,
  1813. return_dict=return_dict,
  1814. **kwargs,
  1815. )
  1816. sequence_output = outputs[0]
  1817. logits = self.qa_outputs(sequence_output)
  1818. start_logits, end_logits = logits.split(1, dim=-1)
  1819. start_logits = start_logits.squeeze(-1).contiguous()
  1820. end_logits = end_logits.squeeze(-1).contiguous()
  1821. total_loss = None
  1822. if start_positions is not None and end_positions is not None:
  1823. # If we are on multi-GPU, split add a dimension
  1824. if len(start_positions.size()) > 1:
  1825. start_positions = start_positions.squeeze(-1)
  1826. if len(end_positions.size()) > 1:
  1827. end_positions = end_positions.squeeze(-1)
  1828. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1829. ignored_index = start_logits.size(1)
  1830. start_positions = start_positions.clamp(0, ignored_index)
  1831. end_positions = end_positions.clamp(0, ignored_index)
  1832. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1833. start_loss = loss_fct(start_logits, start_positions)
  1834. end_loss = loss_fct(end_logits, end_positions)
  1835. total_loss = (start_loss + end_loss) / 2
  1836. if not return_dict:
  1837. output = (start_logits, end_logits) + outputs[1:]
  1838. return ((total_loss,) + output) if total_loss is not None else output
  1839. return XLNetForQuestionAnsweringSimpleOutput(
  1840. loss=total_loss,
  1841. start_logits=start_logits,
  1842. end_logits=end_logits,
  1843. mems=outputs.mems,
  1844. hidden_states=outputs.hidden_states,
  1845. attentions=outputs.attentions,
  1846. )
  1847. @auto_docstring
  1848. class XLNetForQuestionAnswering(XLNetPreTrainedModel):
  1849. def __init__(self, config):
  1850. super().__init__(config)
  1851. self.start_n_top = config.start_n_top
  1852. self.end_n_top = config.end_n_top
  1853. self.transformer = XLNetModel(config)
  1854. self.start_logits = XLNetPoolerStartLogits(config)
  1855. self.end_logits = XLNetPoolerEndLogits(config)
  1856. self.answer_class = XLNetPoolerAnswerClass(config)
  1857. # Initialize weights and apply final processing
  1858. self.post_init()
  1859. @auto_docstring
  1860. def forward(
  1861. self,
  1862. input_ids: Optional[torch.Tensor] = None,
  1863. attention_mask: Optional[torch.Tensor] = None,
  1864. mems: Optional[torch.Tensor] = None,
  1865. perm_mask: Optional[torch.Tensor] = None,
  1866. target_mapping: Optional[torch.Tensor] = None,
  1867. token_type_ids: Optional[torch.Tensor] = None,
  1868. input_mask: Optional[torch.Tensor] = None,
  1869. head_mask: Optional[torch.Tensor] = None,
  1870. inputs_embeds: Optional[torch.Tensor] = None,
  1871. start_positions: Optional[torch.Tensor] = None,
  1872. end_positions: Optional[torch.Tensor] = None,
  1873. is_impossible: Optional[torch.Tensor] = None,
  1874. cls_index: Optional[torch.Tensor] = None,
  1875. p_mask: Optional[torch.Tensor] = None,
  1876. use_mems: Optional[bool] = None,
  1877. output_attentions: Optional[bool] = None,
  1878. output_hidden_states: Optional[bool] = None,
  1879. return_dict: Optional[bool] = None,
  1880. **kwargs, # delete when `use_cache` is removed in XLNetModel
  1881. ) -> Union[tuple, XLNetForQuestionAnsweringOutput]:
  1882. r"""
  1883. mems (`list[torch.FloatTensor]` of length `config.n_layers`):
  1884. Contains pre-computed hidden-states (see `mems` output below) . Can be used to speed up sequential
  1885. decoding. The token ids which have their past given to this model should not be passed as `input_ids` as
  1886. they have already been computed.
  1887. `use_mems` has to be set to `True` to make use of `mems`.
  1888. perm_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length)`, *optional*):
  1889. Mask to indicate the attention pattern for each input token with values selected in `[0, 1]`:
  1890. - if `perm_mask[k, i, j] = 0`, i attend to j in batch k;
  1891. - if `perm_mask[k, i, j] = 1`, i does not attend to j in batch k.
  1892. If not set, each token attends to all the others (full bidirectional attention). Only used during
  1893. pretraining (to define factorization order) or for sequential decoding (generation).
  1894. target_mapping (`torch.FloatTensor` of shape `(batch_size, num_predict, sequence_length)`, *optional*):
  1895. Mask to indicate the output tokens to use. If `target_mapping[k, i, j] = 1`, the i-th predict in batch k is
  1896. on the j-th token. Only used during pretraining for partial prediction or for sequential decoding
  1897. (generation).
  1898. input_mask (`torch.FloatTensor` of shape `batch_size, sequence_length`, *optional*):
  1899. Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for
  1900. real tokens and 1 for padding which is kept for compatibility with the original code base.
  1901. Mask values selected in `[0, 1]`:
  1902. - 1 for tokens that are **masked**,
  1903. - 0 for tokens that are **not masked**.
  1904. You can only uses one of `input_mask` and `attention_mask`.
  1905. is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1906. Labels whether a question has an answer or no answer (SQuAD 2.0)
  1907. cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1908. Labels for position (index) of the classification token to use as input for computing plausibility of the
  1909. answer.
  1910. p_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1911. Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...). 1.0 means token should be
  1912. masked. 0.0 mean token is not masked.
  1913. use_mems (`bool`, *optional*):
  1914. Whether to use memory states to speed up sequential decoding. If set to `True`, the model will use the hidden
  1915. states from previous forward passes to compute attention, which can significantly improve performance for
  1916. sequential decoding tasks.
  1917. Example:
  1918. ```python
  1919. >>> from transformers import AutoTokenizer, XLNetForQuestionAnswering
  1920. >>> import torch
  1921. >>> tokenizer = AutoTokenizer.from_pretrained("xlnet/xlnet-base-cased")
  1922. >>> model = XLNetForQuestionAnswering.from_pretrained("xlnet/xlnet-base-cased")
  1923. >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(
  1924. ... 0
  1925. ... ) # Batch size 1
  1926. >>> start_positions = torch.tensor([1])
  1927. >>> end_positions = torch.tensor([3])
  1928. >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
  1929. >>> loss = outputs.loss
  1930. ```"""
  1931. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1932. transformer_outputs = self.transformer(
  1933. input_ids,
  1934. attention_mask=attention_mask,
  1935. mems=mems,
  1936. perm_mask=perm_mask,
  1937. target_mapping=target_mapping,
  1938. token_type_ids=token_type_ids,
  1939. input_mask=input_mask,
  1940. head_mask=head_mask,
  1941. inputs_embeds=inputs_embeds,
  1942. use_mems=use_mems,
  1943. output_attentions=output_attentions,
  1944. output_hidden_states=output_hidden_states,
  1945. return_dict=return_dict,
  1946. **kwargs,
  1947. )
  1948. hidden_states = transformer_outputs[0]
  1949. start_logits = self.start_logits(hidden_states, p_mask=p_mask)
  1950. outputs = transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
  1951. if start_positions is not None and end_positions is not None:
  1952. # If we are on multi-GPU, let's remove the dimension added by batch splitting
  1953. for x in (start_positions, end_positions, cls_index, is_impossible):
  1954. if x is not None and x.dim() > 1:
  1955. x.squeeze_(-1)
  1956. # during training, compute the end logits based on the ground truth of the start position
  1957. end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
  1958. loss_fct = CrossEntropyLoss()
  1959. start_loss = loss_fct(start_logits, start_positions)
  1960. end_loss = loss_fct(end_logits, end_positions)
  1961. total_loss = (start_loss + end_loss) / 2
  1962. if cls_index is not None and is_impossible is not None:
  1963. # Predict answerability from the representation of CLS and START
  1964. cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
  1965. loss_fct_cls = nn.BCEWithLogitsLoss()
  1966. cls_loss = loss_fct_cls(cls_logits, is_impossible)
  1967. # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
  1968. total_loss += cls_loss * 0.5
  1969. if not return_dict:
  1970. return (total_loss,) + transformer_outputs[1:]
  1971. else:
  1972. return XLNetForQuestionAnsweringOutput(
  1973. loss=total_loss,
  1974. mems=transformer_outputs.mems,
  1975. hidden_states=transformer_outputs.hidden_states,
  1976. attentions=transformer_outputs.attentions,
  1977. )
  1978. else:
  1979. # during inference, compute the end logits based on beam search
  1980. bsz, slen, hsz = hidden_states.size()
  1981. start_log_probs = nn.functional.softmax(start_logits, dim=-1) # shape (bsz, slen)
  1982. start_top_log_probs, start_top_index = torch.topk(
  1983. start_log_probs, self.start_n_top, dim=-1
  1984. ) # shape (bsz, start_n_top)
  1985. start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
  1986. start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
  1987. start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
  1988. hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
  1989. start_states
  1990. ) # shape (bsz, slen, start_n_top, hsz)
  1991. p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
  1992. end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
  1993. end_log_probs = nn.functional.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
  1994. end_top_log_probs, end_top_index = torch.topk(
  1995. end_log_probs, self.end_n_top, dim=1
  1996. ) # shape (bsz, end_n_top, start_n_top)
  1997. end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
  1998. end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
  1999. start_states = torch.einsum(
  2000. "blh,bl->bh", hidden_states, start_log_probs
  2001. ) # get the representation of START as weighted sum of hidden states
  2002. cls_logits = self.answer_class(
  2003. hidden_states, start_states=start_states, cls_index=cls_index
  2004. ) # Shape (batch size,): one single `cls_logits` for each sample
  2005. if not return_dict:
  2006. outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
  2007. return outputs + transformer_outputs[1:]
  2008. else:
  2009. return XLNetForQuestionAnsweringOutput(
  2010. start_top_log_probs=start_top_log_probs,
  2011. start_top_index=start_top_index,
  2012. end_top_log_probs=end_top_log_probs,
  2013. end_top_index=end_top_index,
  2014. cls_logits=cls_logits,
  2015. mems=transformer_outputs.mems,
  2016. hidden_states=transformer_outputs.hidden_states,
  2017. attentions=transformer_outputs.attentions,
  2018. )
  2019. __all__ = [
  2020. "XLNetForMultipleChoice",
  2021. "XLNetForQuestionAnswering",
  2022. "XLNetForQuestionAnsweringSimple",
  2023. "XLNetForSequenceClassification",
  2024. "XLNetForTokenClassification",
  2025. "XLNetLMHeadModel",
  2026. "XLNetModel",
  2027. "XLNetPreTrainedModel",
  2028. "load_tf_weights_in_xlnet",
  2029. ]