modeling_mpt.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824
  1. # coding=utf-8
  2. # Copyright 2023 HuggingFace Inc. team and MosaicML NLP team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch MPT model."""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
  21. from torch.nn import functional as F
  22. from ...cache_utils import Cache, DynamicCache
  23. from ...generation import GenerationMixin
  24. from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutputWithPastAndCrossAttentions,
  28. CausalLMOutputWithCrossAttentions,
  29. QuestionAnsweringModelOutput,
  30. SequenceClassifierOutputWithPast,
  31. TokenClassifierOutput,
  32. )
  33. from ...modeling_utils import PreTrainedModel
  34. from ...utils import auto_docstring, logging
  35. from ...utils.deprecation import deprecate_kwarg
  36. from .configuration_mpt import MptConfig
  37. logger = logging.get_logger(__name__)
  38. def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device=None):
  39. r"""
  40. Link to paper: https://huggingface.co/papers/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
  41. relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
  42. the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
  43. https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
  44. """
  45. alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length)
  46. num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads))
  47. base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.int64, device=device).float()
  48. base = base * (alibi_bias_max / num_heads_power_of_2)
  49. slopes = 1.0 / torch.pow(2, base)
  50. slopes = slopes.view(1, num_heads_power_of_2, 1, 1)
  51. if num_heads_power_of_2 != num_heads:
  52. slopes = torch.concat([slopes[:, 1::2, ...], slopes[:, ::2, ...]], dim=1)[:, :num_heads, ...]
  53. alibi = alibi * slopes
  54. return alibi.squeeze(0)
  55. class MptAttention(nn.Module):
  56. """Multi-head self attention.
  57. Using torch or triton attention implementation enables user to also use additive bias.
  58. """
  59. def __init__(self, config: MptConfig, layer_idx: Optional[int] = None):
  60. super().__init__()
  61. self.hidden_size = config.hidden_size
  62. self.n_heads = config.n_heads
  63. self.max_seq_length = config.max_seq_len
  64. self.head_dim = self.hidden_size // self.n_heads
  65. self.softmax_scale = config.attn_config.softmax_scale
  66. if self.softmax_scale is None:
  67. self.softmax_scale = 1 / math.sqrt(self.hidden_size / self.n_heads)
  68. self.attn_dropout_p = config.attn_config.attn_pdrop
  69. self.clip_qkv = config.attn_config.clip_qkv
  70. self.Wqkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
  71. self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
  72. self.layer_idx = layer_idx
  73. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  74. def forward(
  75. self,
  76. hidden_states: torch.Tensor,
  77. position_bias: torch.Tensor,
  78. past_key_values: Optional[Cache] = None,
  79. attention_mask: Optional[torch.Tensor] = None,
  80. cache_position: Optional[torch.Tensor] = None,
  81. ):
  82. batch_size, seq_length = hidden_states.shape[:2]
  83. mixed_qkv = self.Wqkv(hidden_states)
  84. if self.clip_qkv:
  85. mixed_qkv = mixed_qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
  86. query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
  87. query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
  88. key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
  89. value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
  90. if past_key_values is not None:
  91. cache_kwargs = {"cache_position": cache_position}
  92. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  93. attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
  94. query_length = seq_length if past_key_values is None else seq_length + past_key_values.get_seq_length()
  95. if position_bias is not None:
  96. if len(position_bias.shape) != 3:
  97. raise ValueError(f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias.shape)}")
  98. key_length = key_states.shape[-2]
  99. position_bias_query_index = max(0, position_bias.size(1) - query_length)
  100. position_bias_key_index = max(0, position_bias.size(2) - key_length)
  101. position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]
  102. attention_scores = attention_scores + position_bias
  103. if attention_mask is not None:
  104. attention_scores = attention_scores.masked_fill(attention_mask, torch.finfo(query_states.dtype).min)
  105. # (batch_size, n_heads, seq_length, key_length)
  106. attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).to(value_states.dtype)
  107. attn_weights = nn.functional.dropout(attn_weights, p=self.attn_dropout_p, training=self.training)
  108. context_states = torch.matmul(attn_weights, value_states)
  109. context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
  110. attn_output = self.out_proj(context_states)
  111. return attn_output, attn_weights
  112. class MptMLP(nn.Module):
  113. def __init__(self, config: MptConfig):
  114. super().__init__()
  115. hidden_size = config.hidden_size
  116. self.up_proj = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
  117. self.act = nn.GELU(approximate="none")
  118. self.down_proj = nn.Linear(4 * hidden_size, hidden_size, bias=False)
  119. self.hidden_dropout = config.attn_config.attn_pdrop
  120. def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
  121. hidden_states = self.act(self.up_proj(hidden_states))
  122. intermediate_output = self.down_proj(hidden_states)
  123. output = F.dropout(intermediate_output, p=self.hidden_dropout, training=self.training)
  124. output = output + residual
  125. return output
  126. class MptBlock(GradientCheckpointingLayer):
  127. def __init__(self, config: MptConfig, layer_idx: Optional[int] = None):
  128. super().__init__()
  129. hidden_size = config.hidden_size
  130. self.norm_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  131. # backward compatibility with weights on the Hub
  132. self.norm_1.bias = None
  133. self.num_heads = config.n_heads
  134. self.attn = MptAttention(config, layer_idx)
  135. self.norm_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  136. # backward compatibility with weights on the Hub
  137. self.norm_2.bias = None
  138. self.ffn = MptMLP(config)
  139. self.dropout_rate = config.attn_config.attn_pdrop
  140. self.resid_attn_dropout = nn.Dropout(self.dropout_rate)
  141. def forward(
  142. self,
  143. hidden_states: torch.Tensor,
  144. position_bias: torch.Tensor,
  145. attention_mask: torch.Tensor,
  146. layer_past: Optional[Cache] = None,
  147. use_cache: bool = False,
  148. output_attentions: bool = False,
  149. cache_position: Optional[torch.Tensor] = None,
  150. ):
  151. # hidden_states: [batch_size, seq_length, hidden_size]
  152. # Layer norm at the beginning of the transformer layer.
  153. layernorm_output = self.norm_1(hidden_states)
  154. residual = hidden_states
  155. # Self attention.
  156. attn_outputs, attn_weights = self.attn(
  157. layernorm_output,
  158. position_bias=position_bias,
  159. attention_mask=attention_mask,
  160. past_key_values=layer_past,
  161. cache_position=cache_position,
  162. )
  163. hidden_states = self.resid_attn_dropout(attn_outputs) + residual
  164. layernorm_output = self.norm_2(hidden_states)
  165. # Get residual
  166. residual = hidden_states
  167. # MLP.
  168. output = self.ffn(layernorm_output, residual)
  169. return output, attn_weights
  170. @auto_docstring
  171. class MptPreTrainedModel(PreTrainedModel):
  172. config: MptConfig
  173. base_model_prefix = "transformer"
  174. supports_gradient_checkpointing = True
  175. _no_split_modules = ["MptBlock"]
  176. _keys_to_ignore_on_load_missing = [r"lm_head.*."]
  177. def __init__(self, *inputs, **kwargs):
  178. super().__init__(*inputs, **kwargs)
  179. def _init_weights(self, module: nn.Module):
  180. """Initialize the weights."""
  181. if isinstance(module, nn.Linear):
  182. # Slightly different from the TF version which uses truncated_normal for initialization
  183. # cf https://github.com/pytorch/pytorch/pull/5617
  184. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  185. if module.bias is not None:
  186. module.bias.data.zero_()
  187. elif isinstance(module, nn.Embedding):
  188. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  189. if module.padding_idx is not None:
  190. module.weight.data[module.padding_idx].zero_()
  191. elif isinstance(module, LayerNorm):
  192. if module.bias is not None:
  193. module.bias.data.zero_()
  194. module.weight.data.fill_(1.0)
  195. @staticmethod
  196. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  197. def _convert_to_mpt_cache(
  198. past_key_values: tuple[tuple[torch.Tensor, torch.Tensor]],
  199. ) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
  200. """
  201. Converts the cache to the format expected by Mpt, i.e. to tuple(tuple([batch_size * num_heads, ...]))
  202. """
  203. batch_size, num_heads, head_dim, seq_length = past_key_values[0][0].shape
  204. batch_size_times_num_heads = batch_size * num_heads
  205. # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
  206. # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
  207. return tuple(
  208. (
  209. layer_past[0].reshape(batch_size_times_num_heads, head_dim, seq_length),
  210. layer_past[1].reshape(batch_size_times_num_heads, seq_length, head_dim),
  211. )
  212. for layer_past in past_key_values
  213. )
  214. @auto_docstring
  215. class MptModel(MptPreTrainedModel):
  216. def __init__(self, config: MptConfig):
  217. super().__init__(config)
  218. self.hidden_size = config.hidden_size
  219. self.num_heads = config.n_heads
  220. # Embedding + LN Embedding
  221. self.wte = nn.Embedding(config.vocab_size, self.hidden_size)
  222. # Transformer blocks
  223. self.blocks = nn.ModuleList([MptBlock(config, layer_idx=i) for i in range(config.n_layers)])
  224. # Final Layer Norm
  225. self.norm_f = LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
  226. # backward compatibility with weights on the Hub
  227. self.norm_f.bias = None
  228. self.gradient_checkpointing = False
  229. # Initialize weights and apply final processing
  230. self.post_init()
  231. def get_input_embeddings(self):
  232. return self.wte
  233. def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None):
  234. return build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max, device)
  235. def set_input_embeddings(self, new_embeddings: torch.Tensor):
  236. self.wte = new_embeddings
  237. @auto_docstring
  238. def forward(
  239. self,
  240. input_ids: Optional[torch.LongTensor] = None,
  241. past_key_values: Optional[Cache] = None,
  242. attention_mask: Optional[torch.Tensor] = None,
  243. inputs_embeds: Optional[torch.LongTensor] = None,
  244. use_cache: Optional[bool] = None,
  245. output_attentions: Optional[bool] = None,
  246. output_hidden_states: Optional[bool] = None,
  247. return_dict: Optional[bool] = None,
  248. cache_position: Optional[torch.Tensor] = None,
  249. **kwargs, # NOOP kwargs, for now
  250. ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
  251. r"""
  252. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  253. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  254. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  255. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  256. `input_ids`.
  257. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  258. [`PreTrainedTokenizer.__call__`] for details.
  259. [What are input IDs?](../glossary#input-ids)
  260. """
  261. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  262. output_hidden_states = (
  263. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  264. )
  265. use_cache = use_cache if use_cache is not None else self.config.use_cache
  266. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  267. if input_ids is not None and inputs_embeds is not None:
  268. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  269. elif input_ids is not None:
  270. batch_size, seq_length = input_ids.shape
  271. elif inputs_embeds is not None:
  272. batch_size, seq_length, _ = inputs_embeds.shape
  273. else:
  274. raise ValueError("You have to specify either input_ids or inputs_embeds")
  275. if self.gradient_checkpointing and self.training:
  276. if use_cache:
  277. logger.warning_once(
  278. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  279. )
  280. use_cache = False
  281. if inputs_embeds is None:
  282. inputs_embeds = self.wte(input_ids)
  283. if use_cache and past_key_values is None:
  284. past_key_values = DynamicCache(config=self.config)
  285. if use_cache and isinstance(past_key_values, tuple):
  286. logger.warning_once(
  287. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  288. "You should pass an instance of `DynamicCache` instead, e.g. "
  289. "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
  290. )
  291. past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  292. hidden_states = inputs_embeds
  293. all_self_attentions = () if output_attentions else None
  294. all_hidden_states = () if output_hidden_states else None
  295. # Compute alibi tensor: check build_alibi_tensor documentation
  296. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  297. seq_length_with_past = seq_length + past_key_values_length
  298. if attention_mask is None:
  299. attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
  300. else:
  301. attention_mask = attention_mask.to(hidden_states.device)
  302. alibi = self.build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device)
  303. causal_mask = _prepare_4d_causal_attention_mask(
  304. attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
  305. )
  306. causal_mask = causal_mask.bool()
  307. for block in self.blocks:
  308. if output_hidden_states:
  309. all_hidden_states = all_hidden_states + (hidden_states,)
  310. outputs = block(
  311. hidden_states,
  312. layer_past=past_key_values,
  313. attention_mask=causal_mask,
  314. use_cache=use_cache,
  315. output_attentions=output_attentions,
  316. position_bias=alibi,
  317. cache_position=cache_position,
  318. )
  319. hidden_states = outputs[0]
  320. if output_attentions:
  321. all_self_attentions = all_self_attentions + (outputs[1],)
  322. # Add last hidden state
  323. hidden_states = self.norm_f(hidden_states)
  324. if output_hidden_states:
  325. all_hidden_states = all_hidden_states + (hidden_states,)
  326. if not return_dict:
  327. return tuple(
  328. v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
  329. )
  330. return BaseModelOutputWithPastAndCrossAttentions(
  331. last_hidden_state=hidden_states,
  332. past_key_values=past_key_values,
  333. hidden_states=all_hidden_states,
  334. attentions=all_self_attentions,
  335. )
  336. @auto_docstring(
  337. custom_intro="""
  338. The MPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
  339. embeddings).
  340. """
  341. )
  342. class MptForCausalLM(MptPreTrainedModel, GenerationMixin):
  343. _tied_weights_keys = ["lm_head.weight"]
  344. def __init__(self, config: MptConfig):
  345. super().__init__(config)
  346. self.transformer = MptModel(config)
  347. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  348. # Initialize weights and apply final processing
  349. self.post_init()
  350. def set_output_embeddings(self, new_embeddings: torch.Tensor):
  351. self.lm_head = new_embeddings
  352. @auto_docstring
  353. def forward(
  354. self,
  355. input_ids: Optional[torch.LongTensor] = None,
  356. past_key_values: Optional[Cache] = None,
  357. attention_mask: Optional[torch.Tensor] = None,
  358. inputs_embeds: Optional[torch.Tensor] = None,
  359. labels: Optional[torch.Tensor] = None,
  360. use_cache: Optional[bool] = None,
  361. output_attentions: Optional[bool] = None,
  362. output_hidden_states: Optional[bool] = None,
  363. return_dict: Optional[bool] = None,
  364. cache_position: Optional[torch.Tensor] = None,
  365. **kwargs,
  366. ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
  367. r"""
  368. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  369. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  370. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  371. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  372. `input_ids`.
  373. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  374. [`PreTrainedTokenizer.__call__`] for details.
  375. [What are input IDs?](../glossary#input-ids)
  376. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  377. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  378. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  379. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  380. """
  381. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  382. transformer_outputs = self.transformer(
  383. input_ids,
  384. past_key_values=past_key_values,
  385. attention_mask=attention_mask,
  386. inputs_embeds=inputs_embeds,
  387. use_cache=use_cache,
  388. output_attentions=output_attentions,
  389. output_hidden_states=output_hidden_states,
  390. return_dict=return_dict,
  391. cache_position=cache_position,
  392. )
  393. hidden_states = transformer_outputs[0]
  394. lm_logits = self.lm_head(hidden_states)
  395. loss = None
  396. if labels is not None:
  397. # move labels to correct device to enable model parallelism
  398. labels = labels.to(lm_logits.device)
  399. # Flatten the tokens
  400. loss = self.loss_function(
  401. lm_logits,
  402. labels,
  403. vocab_size=self.config.vocab_size,
  404. **kwargs,
  405. )
  406. if not return_dict:
  407. output = (lm_logits,) + transformer_outputs[1:]
  408. return ((loss,) + output) if loss is not None else output
  409. return CausalLMOutputWithCrossAttentions(
  410. loss=loss,
  411. logits=lm_logits,
  412. past_key_values=transformer_outputs.past_key_values,
  413. hidden_states=transformer_outputs.hidden_states,
  414. attentions=transformer_outputs.attentions,
  415. )
  416. @auto_docstring(
  417. custom_intro="""
  418. The MPT Model transformer with a sequence classification head on top (linear layer).
  419. [`MptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  420. (e.g. GPT-1) do.
  421. Since it does classification on the last token, it requires to know the position of the last token. If a
  422. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  423. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  424. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  425. each row of the batch).
  426. """
  427. )
  428. class MptForSequenceClassification(MptPreTrainedModel):
  429. def __init__(self, config: MptConfig):
  430. super().__init__(config)
  431. self.num_labels = config.num_labels
  432. self.transformer = MptModel(config)
  433. self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
  434. # Initialize weights and apply final processing
  435. self.post_init()
  436. @auto_docstring
  437. def forward(
  438. self,
  439. input_ids: Optional[torch.LongTensor] = None,
  440. past_key_values: Optional[Cache] = None,
  441. attention_mask: Optional[torch.Tensor] = None,
  442. inputs_embeds: Optional[torch.Tensor] = None,
  443. labels: Optional[torch.Tensor] = None,
  444. use_cache: Optional[bool] = None,
  445. output_attentions: Optional[bool] = None,
  446. output_hidden_states: Optional[bool] = None,
  447. return_dict: Optional[bool] = None,
  448. ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
  449. r"""
  450. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  451. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  452. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  453. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  454. `input_ids`.
  455. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  456. [`PreTrainedTokenizer.__call__`] for details.
  457. [What are input IDs?](../glossary#input-ids)
  458. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  459. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  460. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  461. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  462. """
  463. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  464. transformer_outputs = self.transformer(
  465. input_ids,
  466. past_key_values=past_key_values,
  467. attention_mask=attention_mask,
  468. inputs_embeds=inputs_embeds,
  469. use_cache=use_cache,
  470. output_attentions=output_attentions,
  471. output_hidden_states=output_hidden_states,
  472. return_dict=return_dict,
  473. )
  474. hidden_states = transformer_outputs[0]
  475. logits = self.score(hidden_states)
  476. if input_ids is not None:
  477. batch_size = input_ids.shape[0]
  478. else:
  479. batch_size = inputs_embeds.shape[0]
  480. if self.config.pad_token_id is None and batch_size != 1:
  481. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  482. if self.config.pad_token_id is None:
  483. last_non_pad_token = -1
  484. elif input_ids is not None:
  485. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  486. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  487. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  488. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  489. else:
  490. last_non_pad_token = -1
  491. logger.warning_once(
  492. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  493. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  494. )
  495. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  496. loss = None
  497. if labels is not None:
  498. if self.config.problem_type is None:
  499. if self.num_labels == 1:
  500. self.config.problem_type = "regression"
  501. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  502. self.config.problem_type = "single_label_classification"
  503. else:
  504. self.config.problem_type = "multi_label_classification"
  505. if self.config.problem_type == "regression":
  506. loss_fct = MSELoss()
  507. if self.num_labels == 1:
  508. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  509. else:
  510. loss = loss_fct(pooled_logits, labels)
  511. elif self.config.problem_type == "single_label_classification":
  512. loss_fct = CrossEntropyLoss()
  513. loss = loss_fct(pooled_logits, labels)
  514. elif self.config.problem_type == "multi_label_classification":
  515. loss_fct = BCEWithLogitsLoss()
  516. loss = loss_fct(pooled_logits, labels)
  517. if not return_dict:
  518. output = (pooled_logits,) + transformer_outputs[1:]
  519. return ((loss,) + output) if loss is not None else output
  520. return SequenceClassifierOutputWithPast(
  521. loss=loss,
  522. logits=pooled_logits,
  523. past_key_values=transformer_outputs.past_key_values,
  524. hidden_states=transformer_outputs.hidden_states,
  525. attentions=transformer_outputs.attentions,
  526. )
  527. @auto_docstring
  528. class MptForTokenClassification(MptPreTrainedModel):
  529. def __init__(self, config: MptConfig):
  530. super().__init__(config)
  531. self.num_labels = config.num_labels
  532. self.transformer = MptModel(config)
  533. if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
  534. classifier_dropout = config.classifier_dropout
  535. elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
  536. classifier_dropout = config.hidden_dropout
  537. else:
  538. classifier_dropout = 0.1
  539. self.dropout = nn.Dropout(classifier_dropout)
  540. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  541. # Initialize weights and apply final processing
  542. self.post_init()
  543. @auto_docstring
  544. def forward(
  545. self,
  546. input_ids: Optional[torch.LongTensor] = None,
  547. past_key_values: Optional[Cache] = None,
  548. attention_mask: Optional[torch.Tensor] = None,
  549. inputs_embeds: Optional[torch.Tensor] = None,
  550. labels: Optional[torch.Tensor] = None,
  551. use_cache: Optional[bool] = None,
  552. output_attentions: Optional[bool] = None,
  553. output_hidden_states: Optional[bool] = None,
  554. return_dict: Optional[bool] = None,
  555. **deprecated_arguments,
  556. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  557. r"""
  558. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  559. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  560. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  561. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  562. `input_ids`.
  563. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  564. [`PreTrainedTokenizer.__call__`] for details.
  565. [What are input IDs?](../glossary#input-ids)
  566. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  567. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  568. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  569. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  570. """
  571. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  572. transformer_outputs = self.transformer(
  573. input_ids,
  574. past_key_values=past_key_values,
  575. attention_mask=attention_mask,
  576. inputs_embeds=inputs_embeds,
  577. use_cache=use_cache,
  578. output_attentions=output_attentions,
  579. output_hidden_states=output_hidden_states,
  580. return_dict=return_dict,
  581. )
  582. hidden_states = transformer_outputs[0]
  583. hidden_states = self.dropout(hidden_states)
  584. logits = self.classifier(hidden_states)
  585. loss = None
  586. if labels is not None:
  587. # move labels to correct device to enable model parallelism
  588. labels = labels.to(logits.device)
  589. batch_size, seq_length = labels.shape
  590. loss_fct = CrossEntropyLoss()
  591. loss = loss_fct(
  592. logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
  593. )
  594. if not return_dict:
  595. output = (logits,) + transformer_outputs[2:]
  596. return ((loss,) + output) if loss is not None else output
  597. return TokenClassifierOutput(
  598. loss=loss,
  599. logits=logits,
  600. hidden_states=transformer_outputs.hidden_states,
  601. attentions=transformer_outputs.attentions,
  602. )
  603. @auto_docstring
  604. class MptForQuestionAnswering(MptPreTrainedModel):
  605. def __init__(self, config):
  606. super().__init__(config)
  607. self.transformer = MptModel(config)
  608. self.qa_outputs = nn.Linear(config.hidden_size, 2)
  609. # Initialize weights and apply final processing
  610. self.post_init()
  611. @auto_docstring
  612. def forward(
  613. self,
  614. input_ids: Optional[torch.LongTensor] = None,
  615. attention_mask: Optional[torch.FloatTensor] = None,
  616. inputs_embeds: Optional[torch.FloatTensor] = None,
  617. start_positions: Optional[torch.LongTensor] = None,
  618. end_positions: Optional[torch.LongTensor] = None,
  619. output_attentions: Optional[bool] = None,
  620. output_hidden_states: Optional[bool] = None,
  621. return_dict: Optional[bool] = None,
  622. ) -> Union[tuple, QuestionAnsweringModelOutput]:
  623. r"""
  624. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  625. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
  626. (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
  627. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  628. `input_ids`.
  629. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  630. [`PreTrainedTokenizer.__call__`] for details.
  631. [What are input IDs?](../glossary#input-ids)
  632. """
  633. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  634. outputs = self.transformer(
  635. input_ids,
  636. attention_mask=attention_mask,
  637. inputs_embeds=inputs_embeds,
  638. output_attentions=output_attentions,
  639. output_hidden_states=output_hidden_states,
  640. return_dict=return_dict,
  641. )
  642. sequence_output = outputs[0]
  643. logits = self.qa_outputs(sequence_output)
  644. start_logits, end_logits = logits.split(1, dim=-1)
  645. start_logits = start_logits.squeeze(-1).contiguous()
  646. end_logits = end_logits.squeeze(-1).contiguous()
  647. total_loss = None
  648. if start_positions is not None and end_positions is not None:
  649. # If we are on multi-GPU, split add a dimension
  650. if len(start_positions.size()) > 1:
  651. start_positions = start_positions.squeeze(-1)
  652. if len(end_positions.size()) > 1:
  653. end_positions = end_positions.squeeze(-1)
  654. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  655. ignored_index = start_logits.size(1)
  656. start_positions = start_positions.clamp(0, ignored_index)
  657. end_positions = end_positions.clamp(0, ignored_index)
  658. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  659. start_loss = loss_fct(start_logits, start_positions)
  660. end_loss = loss_fct(end_logits, end_positions)
  661. total_loss = (start_loss + end_loss) / 2
  662. if not return_dict:
  663. output = (start_logits, end_logits) + outputs[2:]
  664. return ((total_loss,) + output) if total_loss is not None else output
  665. return QuestionAnsweringModelOutput(
  666. loss=total_loss,
  667. start_logits=start_logits,
  668. end_logits=end_logits,
  669. hidden_states=outputs.hidden_states,
  670. attentions=outputs.attentions,
  671. )
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "MptForCausalLM",
    "MptModel",
    "MptPreTrainedModel",
    "MptForSequenceClassification",
    "MptForTokenClassification",
    "MptForQuestionAnswering",
]