modeling_bloom.py 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252
  1. # coding=utf-8
  2. # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch BLOOM model."""
  16. import math
  17. import warnings
  18. from typing import Optional, Union
  19. import torch
  20. from torch import nn
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
  22. from torch.nn import functional as F
  23. from ...cache_utils import Cache, DynamicCache, StaticCache
  24. from ...generation import GenerationMixin
  25. from ...modeling_attn_mask_utils import AttentionMaskConverter
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import (
  28. BaseModelOutputWithPastAndCrossAttentions,
  29. CausalLMOutputWithCrossAttentions,
  30. QuestionAnsweringModelOutput,
  31. SequenceClassifierOutputWithPast,
  32. TokenClassifierOutput,
  33. )
  34. from ...modeling_utils import PreTrainedModel
  35. from ...utils import (
  36. auto_docstring,
  37. is_torch_flex_attn_available,
  38. logging,
  39. )
  40. from .configuration_bloom import BloomConfig
  41. if is_torch_flex_attn_available():
  42. from torch.nn.attention.flex_attention import BlockMask
  43. from ...integrations.flex_attention import make_flex_block_causal_mask
  44. logger = logging.get_logger(__name__)
  45. def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
  46. """
  47. Link to paper: https://huggingface.co/papers/2108.12409 Alibi tensor is not causal as the original paper mentions, it
  48. relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
  49. `softmax(l+a) = softmax(l)`. Based on
  50. https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
  51. TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
  52. Args:
  53. Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
  54. attention_mask (`torch.Tensor`):
  55. Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
  56. num_heads (`int`):
  57. number of heads
  58. dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
  59. dtype of the output tensor
  60. """
  61. batch_size, seq_length = attention_mask.shape
  62. closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
  63. base = torch.tensor(
  64. 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
  65. )
  66. powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
  67. slopes = torch.pow(base, powers)
  68. if closest_power_of_2 != num_heads:
  69. extra_base = torch.tensor(
  70. 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
  71. )
  72. num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
  73. extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
  74. slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
  75. # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
  76. # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
  77. # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
  78. # => the query_length dimension will then be broadcasted correctly
  79. # This is more or less identical to T5's relative position bias:
  80. # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
  81. arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
  82. alibi = slopes[..., None] * arange_tensor
  83. return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
  84. def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
  85. """
  86. Dropout add function
  87. Args:
  88. x (`torch.tensor`):
  89. input tensor
  90. residual (`torch.tensor`):
  91. residual tensor
  92. prob (`float`):
  93. dropout probability
  94. training (`bool`):
  95. training mode
  96. """
  97. out = F.dropout(x, p=prob, training=training)
  98. out = residual + out
  99. return out
  100. def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
  101. """
  102. Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
  103. make the model jitable.
  104. Args:
  105. x (`torch.tensor`):
  106. input hidden states
  107. """
  108. return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
  109. def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
  110. """
  111. gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
  112. 0.3989423 * x * torch.exp(-0.5 * x * x)
  113. Args:
  114. g (`torch.tensor`):
  115. gradient output tensor
  116. x (`torch.tensor`):
  117. input tensor
  118. """
  119. x = x[0] # x is a tuple of 1 element, needs to unpack it first
  120. tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
  121. # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
  122. ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
  123. return ff * g
  124. class GeLUFunction(torch.autograd.Function):
  125. @staticmethod
  126. def forward(ctx, input: torch.Tensor) -> torch.Tensor:
  127. ctx.save_for_backward(input)
  128. return bloom_gelu_forward(input)
  129. @staticmethod
  130. def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
  131. input = ctx.saved_tensors
  132. tmp = bloom_gelu_back(grad_output, input)
  133. return tmp
  134. class BloomGelu(nn.Module):
  135. """
  136. BloomBiasGelu wrapper function that make use of the simple function on inference mode to make the model
  137. torchscriptable and use the autograd function in training mode to get the accurate results of the gradients Partly
  138. copied from Megatron-DeepSpeed code and adapted for our needs
  139. See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
  140. """
  141. def __init__(self):
  142. super().__init__()
  143. def forward(self, x: torch.Tensor) -> torch.Tensor:
  144. if self.training:
  145. return GeLUFunction.apply(x)
  146. else:
  147. return bloom_gelu_forward(x)
  148. class BloomAttention(nn.Module):
  149. def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
  150. super().__init__()
  151. self.pretraining_tp = config.pretraining_tp
  152. self.slow_but_exact = config.slow_but_exact
  153. self.hidden_size = config.hidden_size
  154. self.num_heads = config.n_head
  155. self.head_dim = self.hidden_size // self.num_heads
  156. self.split_size = self.hidden_size
  157. self.hidden_dropout = config.hidden_dropout
  158. if self.head_dim * self.num_heads != self.hidden_size:
  159. raise ValueError(
  160. f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
  161. f" {self.num_heads})."
  162. )
  163. # Layer-wise attention scaling
  164. self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
  165. self.beta = 1.0
  166. self.layer_idx = layer_idx
  167. if layer_idx is None:
  168. logger.warning_once(
  169. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  170. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  171. "when creating this class."
  172. )
  173. self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
  174. self.dense = nn.Linear(self.hidden_size, self.hidden_size)
  175. self.attention_dropout = nn.Dropout(config.attention_dropout)
  176. def _reshape(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  177. """
  178. Split the last dimension into (num_heads, head_dim) and reshapes to (bs, heads, len, dim) shape
  179. without making any copies, results share same memory storage as `fused_qkv`
  180. Args:
  181. fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]
  182. Returns:
  183. query: [batch_size, num_heads, seq_length, head_dim]
  184. key: [batch_size, num_heads, seq_length, head_dim]
  185. value: [batch_size, num_heads, seq_length, head_dim]
  186. """
  187. batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
  188. fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
  189. query_layer = fused_qkv[..., 0, :].transpose(1, 2)
  190. key_layer = fused_qkv[..., 1, :].transpose(1, 2)
  191. value_layer = fused_qkv[..., 2, :].transpose(1, 2)
  192. return query_layer, key_layer, value_layer
  193. def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
  194. """
  195. Merge heads together over the last dimension
  196. Args:
  197. x (`torch.tensor`): [batch_size * num_heads, seq_length, head_dim]
  198. Returns:
  199. torch.tensor: [batch_size, seq_length, num_heads * head_dim]
  200. """
  201. # What we want to achieve is:
  202. # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
  203. batch_size_and_num_heads, seq_length, _ = x.shape
  204. batch_size = batch_size_and_num_heads // self.num_heads
  205. # First view to decompose the batch size
  206. # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
  207. x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
  208. # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
  209. x = x.permute(0, 2, 1, 3)
  210. # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
  211. return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
  212. def forward(
  213. self,
  214. hidden_states: torch.Tensor,
  215. residual: torch.Tensor,
  216. alibi: torch.Tensor,
  217. attention_mask: torch.Tensor,
  218. layer_past: Optional[Cache] = None,
  219. head_mask: Optional[torch.Tensor] = None,
  220. use_cache: bool = False,
  221. output_attentions: bool = False,
  222. cache_position: Optional[torch.LongTensor] = None,
  223. ):
  224. batch_size, q_length, _ = hidden_states.shape
  225. fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
  226. # 3 x [batch_size, num_heads, seq_length, head_dim]
  227. query_layer, key_layer, value_layer = self._reshape(fused_qkv)
  228. if layer_past is not None:
  229. cache_kwargs = {"cache_position": cache_position}
  230. key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
  231. # reshape qkv for further computations
  232. query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
  233. key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
  234. value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
  235. # [batch_size * num_heads, q_length, kv_length]
  236. attention_scores = alibi.baddbmm(
  237. batch1=query_layer,
  238. batch2=key_layer,
  239. beta=self.beta,
  240. alpha=self.inv_norm_factor,
  241. )
  242. # change view to [batch_size, num_heads, q_length, kv_length]
  243. attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
  244. if attention_mask is not None: # no matter the length, we just slice it
  245. causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
  246. attn_weights = attn_weights + causal_mask
  247. # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
  248. attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype)
  249. # [batch_size, num_heads, q_length, kv_length]
  250. attention_probs = self.attention_dropout(attention_probs)
  251. if head_mask is not None:
  252. attention_probs = attention_probs * head_mask
  253. # change view [batch_size x num_heads, q_length, kv_length]
  254. attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)
  255. # matmul: [batch_size * num_heads, q_length, head_dim]
  256. context_layer = torch.bmm(attention_probs_reshaped, value_layer)
  257. # change view [batch_size, q_length, num_heads * head_dim]
  258. context_layer = self._merge_heads(context_layer)
  259. # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
  260. if self.pretraining_tp > 1 and self.slow_but_exact:
  261. slices = self.hidden_size / self.pretraining_tp
  262. output_tensor = torch.zeros_like(context_layer)
  263. for i in range(self.pretraining_tp):
  264. output_tensor = output_tensor + F.linear(
  265. context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
  266. self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
  267. )
  268. else:
  269. output_tensor = self.dense(context_layer)
  270. output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
  271. return output_tensor, attention_probs
  272. class BloomMLP(nn.Module):
  273. def __init__(self, config: BloomConfig):
  274. super().__init__()
  275. hidden_size = config.hidden_size
  276. self.pretraining_tp = config.pretraining_tp
  277. self.slow_but_exact = config.slow_but_exact
  278. self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
  279. self.gelu_impl = BloomGelu()
  280. self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
  281. self.hidden_dropout = config.hidden_dropout
  282. def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
  283. hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
  284. if self.pretraining_tp > 1 and self.slow_but_exact:
  285. intermediate_output = torch.zeros_like(residual)
  286. slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
  287. for i in range(self.pretraining_tp):
  288. intermediate_output = intermediate_output + F.linear(
  289. hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
  290. self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
  291. )
  292. else:
  293. intermediate_output = self.dense_4h_to_h(hidden_states)
  294. output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
  295. return output
  296. class BloomBlock(GradientCheckpointingLayer):
  297. def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
  298. super().__init__()
  299. hidden_size = config.hidden_size
  300. self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  301. self.num_heads = config.n_head
  302. self.self_attention = BloomAttention(config, layer_idx)
  303. self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  304. self.mlp = BloomMLP(config)
  305. self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
  306. self.hidden_dropout = config.hidden_dropout
  307. def forward(
  308. self,
  309. hidden_states: torch.Tensor,
  310. alibi: torch.Tensor,
  311. attention_mask: torch.Tensor,
  312. layer_past: Optional[Cache] = None,
  313. head_mask: Optional[torch.Tensor] = None,
  314. use_cache: bool = False,
  315. output_attentions: bool = False,
  316. cache_position: Optional[torch.LongTensor] = None,
  317. ):
  318. # hidden_states: [batch_size, seq_length, hidden_size]
  319. # Layer norm at the beginning of the transformer layer.
  320. layernorm_output = self.input_layernorm(hidden_states)
  321. # Layer norm post the self attention.
  322. if self.apply_residual_connection_post_layernorm:
  323. residual = layernorm_output
  324. else:
  325. residual = hidden_states
  326. # Self attention.
  327. attention_output, attn_weights = self.self_attention(
  328. layernorm_output,
  329. residual,
  330. layer_past=layer_past,
  331. attention_mask=attention_mask,
  332. alibi=alibi,
  333. head_mask=head_mask,
  334. use_cache=use_cache,
  335. output_attentions=output_attentions,
  336. cache_position=cache_position,
  337. )
  338. layernorm_output = self.post_attention_layernorm(attention_output)
  339. # Get residual
  340. if self.apply_residual_connection_post_layernorm:
  341. residual = layernorm_output
  342. else:
  343. residual = attention_output
  344. # MLP.
  345. output = self.mlp(layernorm_output, residual)
  346. return output, attn_weights # hidden_states, attentions
  347. @auto_docstring
  348. class BloomPreTrainedModel(PreTrainedModel):
  349. config: BloomConfig
  350. base_model_prefix = "transformer"
  351. supports_gradient_checkpointing = True
  352. _no_split_modules = ["BloomBlock"]
  353. _skip_keys_device_placement = "past_key_values"
  354. _can_compile_fullgraph = True
  355. def __init__(self, *inputs, **kwargs):
  356. super().__init__(*inputs, **kwargs)
  357. def _init_weights(self, module: nn.Module):
  358. """Initialize the weights."""
  359. if isinstance(module, nn.Linear):
  360. # Slightly different from the TF version which uses truncated_normal for initialization
  361. # cf https://github.com/pytorch/pytorch/pull/5617
  362. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  363. if module.bias is not None:
  364. module.bias.data.zero_()
  365. elif isinstance(module, nn.Embedding):
  366. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  367. if module.padding_idx is not None:
  368. module.weight.data[module.padding_idx].zero_()
  369. elif isinstance(module, LayerNorm):
  370. module.bias.data.zero_()
  371. module.weight.data.fill_(1.0)
  372. @auto_docstring
  373. class BloomModel(BloomPreTrainedModel):
  374. def __init__(self, config: BloomConfig):
  375. super().__init__(config)
  376. self.embed_dim = config.hidden_size
  377. self.num_heads = config.n_head
  378. # Embedding + LN Embedding
  379. self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
  380. self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  381. # Transformer blocks
  382. self.h = nn.ModuleList([BloomBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  383. # Final Layer Norm
  384. self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  385. self.gradient_checkpointing = False
  386. # Initialize weights and apply final processing
  387. self.post_init()
  388. def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
  389. return build_alibi_tensor(attention_mask, num_heads, dtype)
  390. def get_input_embeddings(self):
  391. return self.word_embeddings
  392. def set_input_embeddings(self, new_embeddings: torch.Tensor):
  393. self.word_embeddings = new_embeddings
  394. @auto_docstring
  395. def forward(
  396. self,
  397. input_ids: Optional[torch.LongTensor] = None,
  398. past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor], ...]]] = None,
  399. attention_mask: Optional[torch.Tensor] = None,
  400. head_mask: Optional[torch.LongTensor] = None,
  401. inputs_embeds: Optional[torch.LongTensor] = None,
  402. use_cache: Optional[bool] = None,
  403. output_attentions: Optional[bool] = None,
  404. output_hidden_states: Optional[bool] = None,
  405. return_dict: Optional[bool] = None,
  406. cache_position: Optional[torch.LongTensor] = None,
  407. **deprecated_arguments,
  408. ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
  409. r"""
  410. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  411. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  412. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  413. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  414. `input_ids`.
  415. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  416. [`PreTrainedTokenizer.__call__`] for details.
  417. [What are input IDs?](../glossary#input-ids)
  418. """
  419. if deprecated_arguments.pop("position_ids", False) is not False:
  420. # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
  421. warnings.warn(
  422. "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
  423. " passing `position_ids`.",
  424. FutureWarning,
  425. )
  426. if len(deprecated_arguments) > 0:
  427. raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
  428. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  429. output_hidden_states = (
  430. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  431. )
  432. use_cache = use_cache if use_cache is not None else self.config.use_cache
  433. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  434. if (input_ids is None) ^ (inputs_embeds is not None):
  435. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  436. if self.gradient_checkpointing and self.training and use_cache:
  437. logger.warning_once(
  438. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  439. )
  440. use_cache = False
  441. if inputs_embeds is None:
  442. inputs_embeds = self.word_embeddings(input_ids)
  443. if use_cache and past_key_values is None:
  444. past_key_values = DynamicCache(config=self.config)
  445. batch_size, seq_length, _ = inputs_embeds.shape
  446. past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  447. seq_length_with_past = seq_length + past_length
  448. if cache_position is None:
  449. cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device)
  450. # Prepare head mask if needed
  451. # 1.0 in head_mask indicate we keep the head
  452. # attention_probs has shape batch_size x num_heads x N x N
  453. # head_mask has shape n_layer x batch x num_heads x N x N
  454. head_mask = self.get_head_mask(head_mask, self.config.n_layer)
  455. hidden_states = self.word_embeddings_layernorm(inputs_embeds)
  456. all_self_attentions = () if output_attentions else None
  457. all_hidden_states = () if output_hidden_states else None
  458. # Compute alibi tensor: check build_alibi_tensor documentation
  459. if attention_mask is None:
  460. attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
  461. else:
  462. attention_mask = attention_mask.to(hidden_states.device)
  463. alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
  464. causal_mask = self._update_causal_mask(
  465. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  466. )
  467. for i, block in enumerate(self.h):
  468. if output_hidden_states:
  469. all_hidden_states = all_hidden_states + (hidden_states,)
  470. outputs = block(
  471. hidden_states,
  472. layer_past=past_key_values,
  473. attention_mask=causal_mask,
  474. head_mask=head_mask[i],
  475. use_cache=use_cache,
  476. output_attentions=output_attentions,
  477. alibi=alibi,
  478. cache_position=cache_position,
  479. )
  480. hidden_states = outputs[0]
  481. if output_attentions:
  482. all_self_attentions = all_self_attentions + (outputs[1],)
  483. # Add last hidden state
  484. hidden_states = self.ln_f(hidden_states)
  485. if output_hidden_states:
  486. all_hidden_states = all_hidden_states + (hidden_states,)
  487. if not return_dict:
  488. return tuple(
  489. v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
  490. )
  491. return BaseModelOutputWithPastAndCrossAttentions(
  492. last_hidden_state=hidden_states,
  493. past_key_values=past_key_values,
  494. hidden_states=all_hidden_states,
  495. attentions=all_self_attentions,
  496. )
  497. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
  498. def _update_causal_mask(
  499. self,
  500. attention_mask: Union[torch.Tensor, "BlockMask"],
  501. input_tensor: torch.Tensor,
  502. cache_position: torch.Tensor,
  503. past_key_values: Cache,
  504. output_attentions: bool = False,
  505. ):
  506. if self.config._attn_implementation == "flash_attention_2":
  507. if attention_mask is not None and (attention_mask == 0.0).any():
  508. return attention_mask
  509. return None
  510. if self.config._attn_implementation == "flex_attention":
  511. if isinstance(attention_mask, torch.Tensor):
  512. attention_mask = make_flex_block_causal_mask(attention_mask)
  513. return attention_mask
  514. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  515. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  516. # to infer the attention mask.
  517. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  518. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  519. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  520. if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
  521. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  522. attention_mask,
  523. inputs_embeds=input_tensor,
  524. past_key_values_length=past_seen_tokens,
  525. is_training=self.training,
  526. ):
  527. return None
  528. dtype = input_tensor.dtype
  529. sequence_length = input_tensor.shape[1]
  530. if using_compilable_cache:
  531. target_length = past_key_values.get_max_cache_shape()
  532. else:
  533. target_length = (
  534. attention_mask.shape[-1]
  535. if isinstance(attention_mask, torch.Tensor)
  536. else past_seen_tokens + sequence_length + 1
  537. )
  538. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  539. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  540. attention_mask,
  541. sequence_length=sequence_length,
  542. target_length=target_length,
  543. dtype=dtype,
  544. cache_position=cache_position,
  545. batch_size=input_tensor.shape[0],
  546. )
  547. if (
  548. self.config._attn_implementation == "sdpa"
  549. and attention_mask is not None
  550. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  551. and not output_attentions
  552. ):
  553. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  554. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  555. # Details: https://github.com/pytorch/pytorch/issues/110213
  556. min_dtype = torch.finfo(dtype).min
  557. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  558. return causal_mask
    @staticmethod
    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: Optional[torch.Tensor],
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ) -> torch.Tensor:
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`, *optional*):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`. A 4D mask is returned unchanged. May be `None`,
                in which case a purely causal mask is built.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence. The created mask is
                allocated on this tensor's device.
            batch_size (`int`):
                Batch size.
            **kwargs:
                Not read by this function (presumably accepted for call-site compatibility).

        Returns:
            `torch.Tensor`: the input 4D mask as-is, or a newly built additive mask of shape
            `(batch_size, 1, sequence_length, target_length)` holding `0` where attention is allowed and
            `torch.finfo(dtype).min` where it is masked.
        """
        # NOTE(review): per the `Copied from` marker above, this body mirrors the GPT-J implementation; any change
        # here should presumably be mirrored there so both copies stay identical.
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            # Start from a fully masked (query, key) grid; allowed positions are zeroed below.
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                # Keep `min_dtype` only strictly above the main diagonal.
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Keep `min_dtype` only for key positions strictly after each query's cache position; the rest become 0.
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            # Add batch/head dims; `expand` broadcasts without copying, hence the clone before in-place edits below.
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # Assuming a 0/1 `attention_mask`, a sum of exactly 0 means "causally visible (0) but padding (0)",
                # so those positions must be masked as well. Only the first `mask_length` key columns are touched.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
@auto_docstring(
    custom_intro="""
    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """
)
class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
    # Parameter names declared as tied (per the class intro: the LM head shares weights with the input embeddings).
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: BloomConfig):
        """Build the BLOOM backbone plus a bias-free `hidden_size -> vocab_size` language-modeling head."""
        super().__init__(config)
        self.transformer = BloomModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def set_output_embeddings(self, new_embeddings: nn.Module):
        """Replace the language-modeling head (`self.lm_head`) with `new_embeddings`."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        """
        Build the keyword arguments for a single `forward` call during generation.

        When a cache is present, `input_ids` (or `inputs_embeds`) is sliced down to the not-yet-processed positions
        given by `cache_position`. When `past_key_values` is a `StaticCache`, the 2D `attention_mask` is right-padded
        with zeros up to the cache's maximum length so its shape stays fixed. Extra `kwargs` are forwarded unless a key
        of the same name was already set here.
        """
        # Overwritten because of the fixed-shape attention mask creation
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
        #              (we can't check exception 3 while compiling)
        # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
        # generate the first token for each sequence. Later use the generated Input ids for continuation.
        # NOTE(review): this branch indexes `cache_position`, so it assumes `cache_position` is provided whenever
        # `past_key_values` is — confirm with the generation loop.
        if past_key_values is not None:
            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
            elif (
                inputs_embeds is not None  # Exception 1
                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
            ):
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the
            # input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in
            # the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        # This part differs from other models because BLOOM needs a 2D mask to construct alibi tensor
        # The only difference is the usage of 2D instead of 4D mask, but the shape will be static
        if isinstance(past_key_values, StaticCache) and attention_mask is not None:
            target_length = past_key_values.get_max_cache_shape()
            batch_size, seq_length = attention_mask.shape
            # Number of not-yet-filled cache slots; assumes the mask is never longer than the static cache
            # (a negative `diff` would presumably make `torch.zeros` raise).
            diff = target_length - seq_length

            # Zeros mark the unfilled cache slots as "do not attend".
            new_attn_mask = torch.zeros(batch_size, diff, device=attention_mask.device, dtype=attention_mask.dtype)
            attention_mask = torch.cat(
                [attention_mask, new_attn_mask],
                dim=-1,
            )

        model_inputs.update(
            {
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )

        # Forward any extra kwargs; keys already set above (e.g. `attention_mask`, `use_cache`) take precedence.
        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value

        return model_inputs

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor], ...]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **deprecated_arguments,
    ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        # Bloom has deprecated kwargs, so we need to pop num_items_in_batch explicitly
        num_items_in_batch = deprecated_arguments.pop("num_items_in_batch", None)
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        # Anything left over is not a recognised (deprecated) argument.
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shifting/flattening is delegated to the shared `self.loss_function` (labels are shifted inside the
            # model, per the docstring); `num_items_in_batch` is passed through for its normalisation.
            loss = self.loss_function(
                lm_logits,
                labels,
                vocab_size=self.config.vocab_size,
                num_items_in_batch=num_items_in_batch,
            )

        if not return_dict:
            # Tuple output: (loss?, logits, *remaining backbone outputs).
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
  765. @auto_docstring(
  766. custom_intro="""
  767. The Bloom Model transformer with a sequence classification head on top (linear layer).
  768. [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  769. (e.g. GPT-1) do.
  770. Since it does classification on the last token, it requires to know the position of the last token. If a
  771. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  772. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  773. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  774. each row of the batch).
  775. """
  776. )
  777. class BloomForSequenceClassification(BloomPreTrainedModel):
  778. def __init__(self, config: BloomConfig):
  779. super().__init__(config)
  780. self.num_labels = config.num_labels
  781. self.transformer = BloomModel(config)
  782. self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
  783. # Initialize weights and apply final processing
  784. self.post_init()
  785. @auto_docstring
  786. def forward(
  787. self,
  788. input_ids: Optional[torch.LongTensor] = None,
  789. past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor], ...]]] = None,
  790. attention_mask: Optional[torch.Tensor] = None,
  791. head_mask: Optional[torch.Tensor] = None,
  792. inputs_embeds: Optional[torch.Tensor] = None,
  793. labels: Optional[torch.Tensor] = None,
  794. use_cache: Optional[bool] = None,
  795. output_attentions: Optional[bool] = None,
  796. output_hidden_states: Optional[bool] = None,
  797. return_dict: Optional[bool] = None,
  798. **deprecated_arguments,
  799. ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
  800. r"""
  801. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  802. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  803. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  804. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  805. `input_ids`.
  806. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  807. [`PreTrainedTokenizer.__call__`] for details.
  808. [What are input IDs?](../glossary#input-ids)
  809. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  810. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  811. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  812. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  813. """
  814. if deprecated_arguments.pop("position_ids", False) is not False:
  815. # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
  816. warnings.warn(
  817. "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
  818. " passing `position_ids`.",
  819. FutureWarning,
  820. )
  821. if len(deprecated_arguments) > 0:
  822. raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
  823. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  824. transformer_outputs = self.transformer(
  825. input_ids,
  826. past_key_values=past_key_values,
  827. attention_mask=attention_mask,
  828. head_mask=head_mask,
  829. inputs_embeds=inputs_embeds,
  830. use_cache=use_cache,
  831. output_attentions=output_attentions,
  832. output_hidden_states=output_hidden_states,
  833. return_dict=return_dict,
  834. )
  835. hidden_states = transformer_outputs[0]
  836. logits = self.score(hidden_states)
  837. if input_ids is not None:
  838. batch_size = input_ids.shape[0]
  839. else:
  840. batch_size = inputs_embeds.shape[0]
  841. if self.config.pad_token_id is None and batch_size != 1:
  842. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  843. if self.config.pad_token_id is None:
  844. last_non_pad_token = -1
  845. elif input_ids is not None:
  846. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  847. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  848. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  849. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  850. else:
  851. last_non_pad_token = -1
  852. logger.warning_once(
  853. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  854. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  855. )
  856. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  857. loss = None
  858. if labels is not None:
  859. if self.config.problem_type is None:
  860. if self.num_labels == 1:
  861. self.config.problem_type = "regression"
  862. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  863. self.config.problem_type = "single_label_classification"
  864. else:
  865. self.config.problem_type = "multi_label_classification"
  866. if self.config.problem_type == "regression":
  867. loss_fct = MSELoss()
  868. if self.num_labels == 1:
  869. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  870. else:
  871. loss = loss_fct(pooled_logits, labels)
  872. elif self.config.problem_type == "single_label_classification":
  873. loss_fct = CrossEntropyLoss()
  874. loss = loss_fct(pooled_logits, labels)
  875. elif self.config.problem_type == "multi_label_classification":
  876. loss_fct = BCEWithLogitsLoss()
  877. loss = loss_fct(pooled_logits, labels)
  878. if not return_dict:
  879. output = (pooled_logits,) + transformer_outputs[1:]
  880. return ((loss,) + output) if loss is not None else output
  881. return SequenceClassifierOutputWithPast(
  882. loss=loss,
  883. logits=pooled_logits,
  884. past_key_values=transformer_outputs.past_key_values,
  885. hidden_states=transformer_outputs.hidden_states,
  886. attentions=transformer_outputs.attentions,
  887. )
  888. @auto_docstring
  889. class BloomForTokenClassification(BloomPreTrainedModel):
  890. def __init__(self, config: BloomConfig):
  891. super().__init__(config)
  892. self.num_labels = config.num_labels
  893. self.transformer = BloomModel(config)
  894. if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
  895. classifier_dropout = config.classifier_dropout
  896. elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
  897. classifier_dropout = config.hidden_dropout
  898. else:
  899. classifier_dropout = 0.1
  900. self.dropout = nn.Dropout(classifier_dropout)
  901. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  902. # Initialize weights and apply final processing
  903. self.post_init()
  904. @auto_docstring
  905. def forward(
  906. self,
  907. input_ids: Optional[torch.LongTensor] = None,
  908. past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor], ...]]] = None,
  909. attention_mask: Optional[torch.Tensor] = None,
  910. head_mask: Optional[torch.Tensor] = None,
  911. inputs_embeds: Optional[torch.Tensor] = None,
  912. labels: Optional[torch.Tensor] = None,
  913. use_cache: Optional[bool] = None,
  914. output_attentions: Optional[bool] = None,
  915. output_hidden_states: Optional[bool] = None,
  916. return_dict: Optional[bool] = None,
  917. **deprecated_arguments,
  918. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  919. r"""
  920. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  921. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  922. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  923. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  924. `input_ids`.
  925. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  926. [`PreTrainedTokenizer.__call__`] for details.
  927. [What are input IDs?](../glossary#input-ids)
  928. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  929. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  930. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  931. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  932. """
  933. if deprecated_arguments.pop("position_ids", False) is not False:
  934. # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
  935. warnings.warn(
  936. "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
  937. " passing `position_ids`.",
  938. FutureWarning,
  939. )
  940. if len(deprecated_arguments) > 0:
  941. raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
  942. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  943. transformer_outputs = self.transformer(
  944. input_ids,
  945. past_key_values=past_key_values,
  946. attention_mask=attention_mask,
  947. head_mask=head_mask,
  948. inputs_embeds=inputs_embeds,
  949. use_cache=use_cache,
  950. output_attentions=output_attentions,
  951. output_hidden_states=output_hidden_states,
  952. return_dict=return_dict,
  953. )
  954. hidden_states = transformer_outputs[0]
  955. hidden_states = self.dropout(hidden_states)
  956. logits = self.classifier(hidden_states)
  957. loss = None
  958. if labels is not None:
  959. # move labels to correct device to enable model parallelism
  960. labels = labels.to(logits.device)
  961. batch_size, seq_length = labels.shape
  962. loss_fct = CrossEntropyLoss()
  963. loss = loss_fct(
  964. logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
  965. )
  966. if not return_dict:
  967. output = (logits,) + transformer_outputs[2:]
  968. return ((loss,) + output) if loss is not None else output
  969. return TokenClassifierOutput(
  970. loss=loss,
  971. logits=logits,
  972. hidden_states=transformer_outputs.hidden_states,
  973. attentions=transformer_outputs.attentions,
  974. )
  975. @auto_docstring
  976. class BloomForQuestionAnswering(BloomPreTrainedModel):
  977. def __init__(self, config):
  978. super().__init__(config)
  979. self.transformer = BloomModel(config)
  980. self.qa_outputs = nn.Linear(config.hidden_size, 2)
  981. # Initialize weights and apply final processing
  982. self.post_init()
  983. @auto_docstring
  984. def forward(
  985. self,
  986. input_ids: Optional[torch.LongTensor] = None,
  987. attention_mask: Optional[torch.FloatTensor] = None,
  988. position_ids: Optional[torch.LongTensor] = None,
  989. head_mask: Optional[torch.FloatTensor] = None,
  990. inputs_embeds: Optional[torch.FloatTensor] = None,
  991. start_positions: Optional[torch.LongTensor] = None,
  992. end_positions: Optional[torch.LongTensor] = None,
  993. output_attentions: Optional[bool] = None,
  994. output_hidden_states: Optional[bool] = None,
  995. return_dict: Optional[bool] = None,
  996. ) -> Union[tuple, QuestionAnsweringModelOutput]:
  997. r"""
  998. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  999. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  1000. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  1001. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  1002. `input_ids`.
  1003. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1004. [`PreTrainedTokenizer.__call__`] for details.
  1005. [What are input IDs?](../glossary#input-ids)
  1006. """
  1007. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1008. outputs = self.transformer(
  1009. input_ids,
  1010. attention_mask=attention_mask,
  1011. position_ids=position_ids,
  1012. head_mask=head_mask,
  1013. inputs_embeds=inputs_embeds,
  1014. output_attentions=output_attentions,
  1015. output_hidden_states=output_hidden_states,
  1016. return_dict=return_dict,
  1017. )
  1018. sequence_output = outputs[0]
  1019. logits = self.qa_outputs(sequence_output)
  1020. start_logits, end_logits = logits.split(1, dim=-1)
  1021. start_logits = start_logits.squeeze(-1).contiguous()
  1022. end_logits = end_logits.squeeze(-1).contiguous()
  1023. total_loss = None
  1024. if start_positions is not None and end_positions is not None:
  1025. # If we are on multi-GPU, split add a dimension
  1026. if len(start_positions.size()) > 1:
  1027. start_positions = start_positions.squeeze(-1)
  1028. if len(end_positions.size()) > 1:
  1029. end_positions = end_positions.squeeze(-1)
  1030. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1031. ignored_index = start_logits.size(1)
  1032. start_positions = start_positions.clamp(0, ignored_index)
  1033. end_positions = end_positions.clamp(0, ignored_index)
  1034. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1035. start_loss = loss_fct(start_logits, start_positions)
  1036. end_loss = loss_fct(end_logits, end_positions)
  1037. total_loss = (start_loss + end_loss) / 2
  1038. if not return_dict:
  1039. output = (start_logits, end_logits) + outputs[2:]
  1040. return ((total_loss,) + output) if total_loss is not None else output
  1041. return QuestionAnsweringModelOutput(
  1042. loss=total_loss,
  1043. start_logits=start_logits,
  1044. end_logits=end_logits,
  1045. hidden_states=outputs.hidden_states,
  1046. attentions=outputs.attentions,
  1047. )
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "BloomForCausalLM",
    "BloomModel",
    "BloomPreTrainedModel",
    "BloomForSequenceClassification",
    "BloomForTokenClassification",
    "BloomForQuestionAnswering",
]