modeling_cohere2.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/cohere2/modular_cohere2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_cohere2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2024 Cohere Inc. HuggingFace Inc. team. All rights reserved.
  9. #
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. from typing import Callable, Optional, Union
  23. import torch
  24. import torch.nn as nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  29. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  32. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  36. from ...utils.deprecation import deprecate_kwarg
  37. from ...utils.generic import check_model_inputs
  38. from .configuration_cohere2 import Cohere2Config
  39. class Cohere2RotaryEmbedding(nn.Module):
  40. inv_freq: torch.Tensor # fix linting for `register_buffer`
  41. def __init__(self, config: Cohere2Config, device=None):
  42. super().__init__()
  43. # BC: "rope_type" was originally "type"
  44. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  45. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  46. else:
  47. self.rope_type = "default"
  48. self.max_seq_len_cached = config.max_position_embeddings
  49. self.original_max_seq_len = config.max_position_embeddings
  50. self.config = config
  51. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  52. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  53. self.register_buffer("inv_freq", inv_freq, persistent=False)
  54. self.original_inv_freq = self.inv_freq
  55. @torch.no_grad()
  56. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  57. def forward(self, x, position_ids):
  58. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  59. position_ids_expanded = position_ids[:, None, :].float()
  60. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  61. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  62. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  63. emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat()
  64. cos = emb.cos() * self.attention_scaling
  65. sin = emb.sin() * self.attention_scaling
  66. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  67. class Cohere2LayerNorm(nn.Module):
  68. def __init__(self, hidden_size=None, eps=1e-5, bias=False):
  69. """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
  70. super().__init__()
  71. self.weight = nn.Parameter(torch.ones(hidden_size))
  72. self.variance_epsilon = eps
  73. def forward(self, hidden_states):
  74. input_dtype = hidden_states.dtype
  75. hidden_states = hidden_states.to(torch.float32)
  76. mean = hidden_states.mean(-1, keepdim=True)
  77. variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
  78. hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
  79. hidden_states = self.weight.to(torch.float32) * hidden_states
  80. return hidden_states.to(input_dtype)
  81. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  82. """
  83. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  84. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  85. """
  86. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  87. if n_rep == 1:
  88. return hidden_states
  89. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  90. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  91. def eager_attention_forward(
  92. module: nn.Module,
  93. query: torch.Tensor,
  94. key: torch.Tensor,
  95. value: torch.Tensor,
  96. attention_mask: Optional[torch.Tensor],
  97. scaling: float,
  98. dropout: float = 0.0,
  99. **kwargs: Unpack[TransformersKwargs],
  100. ):
  101. key_states = repeat_kv(key, module.num_key_value_groups)
  102. value_states = repeat_kv(value, module.num_key_value_groups)
  103. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  104. if attention_mask is not None:
  105. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  106. attn_weights = attn_weights + causal_mask
  107. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  108. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  109. attn_output = torch.matmul(attn_weights, value_states)
  110. attn_output = attn_output.transpose(1, 2).contiguous()
  111. return attn_output, attn_weights
  112. def rotate_half(x):
  113. # Split and rotate. Note that this function is different from e.g. Llama.
  114. x1 = x[..., ::2]
  115. x2 = x[..., 1::2]
  116. rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
  117. return rot_x
  118. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  119. """Applies Rotary Position Embedding to the query and key tensors.
  120. Args:
  121. q (`torch.Tensor`): The query tensor.
  122. k (`torch.Tensor`): The key tensor.
  123. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  124. sin (`torch.Tensor`): The sine part of the rotary embedding.
  125. position_ids (`torch.Tensor`, *optional*):
  126. Deprecated and unused.
  127. unsqueeze_dim (`int`, *optional*, defaults to 1):
  128. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  129. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  130. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  131. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  132. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  133. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  134. Returns:
  135. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  136. """
  137. dtype = q.dtype
  138. q = q.float()
  139. k = k.float()
  140. cos = cos.unsqueeze(unsqueeze_dim)
  141. sin = sin.unsqueeze(unsqueeze_dim)
  142. q_embed = (q * cos) + (rotate_half(q) * sin)
  143. k_embed = (k * cos) + (rotate_half(k) * sin)
  144. return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
  145. class Cohere2Attention(nn.Module):
  146. """Multi-headed attention from 'Attention Is All You Need' paper"""
  147. def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None):
  148. super().__init__()
  149. self.config = config
  150. self.layer_idx = layer_idx
  151. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  152. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  153. self.scaling = self.head_dim**-0.5
  154. self.attention_dropout = config.attention_dropout
  155. self.is_causal = True
  156. self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
  157. self.q_proj = nn.Linear(
  158. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  159. )
  160. self.k_proj = nn.Linear(
  161. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  162. )
  163. self.v_proj = nn.Linear(
  164. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  165. )
  166. self.o_proj = nn.Linear(
  167. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  168. )
  169. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  170. def forward(
  171. self,
  172. hidden_states: torch.Tensor,
  173. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  174. attention_mask: Optional[torch.Tensor],
  175. past_key_values: Optional[Cache] = None,
  176. cache_position: Optional[torch.LongTensor] = None,
  177. **kwargs: Unpack[FlashAttentionKwargs],
  178. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  179. input_shape = hidden_states.shape[:-1]
  180. hidden_shape = (*input_shape, -1, self.head_dim)
  181. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  182. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  183. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  184. cos, sin = position_embeddings
  185. if self.sliding_window is not None:
  186. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  187. if past_key_values is not None:
  188. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  189. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  190. attention_interface: Callable = eager_attention_forward
  191. if self.config._attn_implementation != "eager":
  192. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  193. attn_output, attn_weights = attention_interface(
  194. self,
  195. query_states,
  196. key_states,
  197. value_states,
  198. attention_mask,
  199. dropout=0.0 if not self.training else self.attention_dropout,
  200. scaling=self.scaling,
  201. sliding_window=self.sliding_window,
  202. **kwargs,
  203. )
  204. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  205. attn_output = self.o_proj(attn_output)
  206. return attn_output, attn_weights
  207. class Cohere2MLP(nn.Module):
  208. def __init__(self, config):
  209. super().__init__()
  210. self.config = config
  211. self.hidden_size = config.hidden_size
  212. self.intermediate_size = config.intermediate_size
  213. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  214. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  215. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  216. self.act_fn = ACT2FN[config.hidden_act]
  217. def forward(self, x):
  218. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  219. return down_proj
  220. class Cohere2DecoderLayer(GradientCheckpointingLayer):
  221. def __init__(self, config: Cohere2Config, layer_idx: int):
  222. super().__init__()
  223. self.hidden_size = config.hidden_size
  224. self.self_attn = Cohere2Attention(config=config, layer_idx=layer_idx)
  225. self.mlp = Cohere2MLP(config)
  226. self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
  227. self.attention_type = config.layer_types[layer_idx]
  228. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  229. def forward(
  230. self,
  231. hidden_states: torch.Tensor,
  232. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  233. attention_mask: Optional[torch.Tensor] = None,
  234. past_key_values: Optional[Cache] = None,
  235. use_cache: Optional[bool] = False,
  236. cache_position: Optional[torch.LongTensor] = None,
  237. **kwargs: Unpack[FlashAttentionKwargs],
  238. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  239. """
  240. Args:
  241. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  242. attention_mask (`torch.FloatTensor`, *optional*):
  243. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  244. query_sequence_length, key_sequence_length)` if default attention is used.
  245. past_key_values (`Cache`, *optional*): cached past key and value projection states
  246. output_attentions (`bool`, *optional*):
  247. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  248. returned tensors for more detail.
  249. use_cache (`bool`, *optional*):
  250. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  251. (see `past_key_values`).
  252. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  253. Indices depicting the position of the input sequence tokens in the sequence
  254. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  255. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  256. with `head_dim` being the embedding dimension of each attention head.
  257. """
  258. residual = hidden_states
  259. hidden_states = self.input_layernorm(hidden_states)
  260. hidden_states_attention, _ = self.self_attn(
  261. hidden_states=hidden_states,
  262. position_embeddings=position_embeddings,
  263. attention_mask=attention_mask,
  264. past_key_values=past_key_values,
  265. use_cache=use_cache,
  266. cache_position=cache_position,
  267. **kwargs,
  268. )
  269. hidden_states_mlp = self.mlp(hidden_states)
  270. hidden_states = residual + hidden_states_attention + hidden_states_mlp
  271. return hidden_states
  272. @auto_docstring
  273. class Cohere2PreTrainedModel(PreTrainedModel):
  274. config: Cohere2Config
  275. base_model_prefix = "model"
  276. supports_gradient_checkpointing = True
  277. _no_split_modules = ["Cohere2DecoderLayer"]
  278. _skip_keys_device_placement = ["past_key_values"]
  279. _supports_flash_attn = True
  280. _supports_sdpa = True
  281. _supports_flex_attn = True
  282. _can_compile_fullgraph = True
  283. _supports_attention_backend = True
  284. _can_record_outputs = {
  285. "hidden_states": Cohere2DecoderLayer,
  286. "attentions": Cohere2Attention,
  287. }
  288. @auto_docstring
  289. class Cohere2Model(Cohere2PreTrainedModel):
  290. def __init__(self, config: Cohere2Config):
  291. super().__init__(config)
  292. self.padding_idx = config.pad_token_id
  293. self.vocab_size = config.vocab_size
  294. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  295. self.layers = nn.ModuleList(
  296. [Cohere2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  297. )
  298. self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
  299. self.rotary_emb = Cohere2RotaryEmbedding(config=config)
  300. self.gradient_checkpointing = False
  301. # Initialize weights and apply final processing
  302. self.post_init()
  303. @check_model_inputs()
  304. @auto_docstring
  305. def forward(
  306. self,
  307. input_ids: Optional[torch.LongTensor] = None,
  308. attention_mask: Optional[torch.Tensor] = None,
  309. position_ids: Optional[torch.LongTensor] = None,
  310. past_key_values: Optional[Cache] = None,
  311. inputs_embeds: Optional[torch.FloatTensor] = None,
  312. use_cache: Optional[bool] = None,
  313. cache_position: Optional[torch.LongTensor] = None,
  314. **kwargs: Unpack[TransformersKwargs],
  315. ) -> BaseModelOutputWithPast:
  316. if (input_ids is None) ^ (inputs_embeds is not None):
  317. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  318. if inputs_embeds is None:
  319. inputs_embeds = self.embed_tokens(input_ids)
  320. if use_cache and past_key_values is None and not self.training:
  321. past_key_values = DynamicCache(config=self.config)
  322. if cache_position is None:
  323. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  324. cache_position = torch.arange(
  325. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  326. )
  327. if position_ids is None:
  328. position_ids = cache_position.unsqueeze(0)
  329. if not isinstance(causal_mask_mapping := attention_mask, dict):
  330. mask_kwargs = {
  331. "config": self.config,
  332. "input_embeds": inputs_embeds,
  333. "attention_mask": attention_mask,
  334. "cache_position": cache_position,
  335. "past_key_values": past_key_values,
  336. "position_ids": position_ids,
  337. }
  338. causal_mask_mapping = {
  339. "full_attention": create_causal_mask(**mask_kwargs),
  340. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  341. }
  342. hidden_states = inputs_embeds
  343. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  344. for decoder_layer in self.layers:
  345. hidden_states = decoder_layer(
  346. hidden_states,
  347. position_embeddings=position_embeddings,
  348. attention_mask=causal_mask_mapping[decoder_layer.attention_type],
  349. past_key_values=past_key_values,
  350. use_cache=use_cache,
  351. cache_position=cache_position,
  352. **kwargs,
  353. )
  354. hidden_states = self.norm(hidden_states)
  355. return BaseModelOutputWithPast(
  356. last_hidden_state=hidden_states,
  357. past_key_values=past_key_values,
  358. )
  359. @auto_docstring
  360. class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
  361. _tied_weights_keys = ["lm_head.weight"]
  362. _tp_plan = {"lm_head": "colwise_rep"}
  363. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  364. def __init__(self, config):
  365. super().__init__(config)
  366. self.model = Cohere2Model(config)
  367. self.vocab_size = config.vocab_size
  368. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  369. self.logit_scale = config.logit_scale
  370. self.tie_word_embeddings = config.tie_word_embeddings
  371. # Initialize weights and apply final processing
  372. self.post_init()
  373. @can_return_tuple
  374. @auto_docstring
  375. def forward(
  376. self,
  377. input_ids: Optional[torch.LongTensor] = None,
  378. attention_mask: Optional[torch.Tensor] = None,
  379. position_ids: Optional[torch.LongTensor] = None,
  380. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  381. inputs_embeds: Optional[torch.FloatTensor] = None,
  382. labels: Optional[torch.LongTensor] = None,
  383. use_cache: Optional[bool] = None,
  384. output_attentions: Optional[bool] = None,
  385. output_hidden_states: Optional[bool] = None,
  386. cache_position: Optional[torch.LongTensor] = None,
  387. logits_to_keep: Union[int, torch.Tensor] = 0,
  388. **kwargs: Unpack[TransformersKwargs],
  389. ) -> CausalLMOutputWithPast:
  390. r"""
  391. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  392. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  393. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  394. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  395. Example:
  396. ```python
  397. >> from transformers import AutoTokenizer, Cohere2ForCausalLM
  398. >> model = Cohere2ForCausalLM.from_pretrained("Cohere2ForAI/c4ai-command-r-v01")
  399. >> tokenizer = AutoTokenizer.from_pretrained("Cohere2ForAI/c4ai-command-r-v01")
  400. >> prompt = "Hey, are you conscious? Can you talk to me?"
  401. >> inputs = tokenizer(prompt, return_tensors="pt")
  402. >> # Generate
  403. >> generate_ids = model.generate(inputs.input_ids, max_length=30)
  404. >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  405. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  406. ```"""
  407. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  408. output_hidden_states = (
  409. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  410. )
  411. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  412. outputs: BaseModelOutputWithPast = self.model(
  413. input_ids=input_ids,
  414. attention_mask=attention_mask,
  415. position_ids=position_ids,
  416. past_key_values=past_key_values,
  417. inputs_embeds=inputs_embeds,
  418. use_cache=use_cache,
  419. output_attentions=output_attentions,
  420. output_hidden_states=output_hidden_states,
  421. cache_position=cache_position,
  422. **kwargs,
  423. )
  424. hidden_states = outputs.last_hidden_state
  425. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  426. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  427. logits = self.lm_head(hidden_states[:, slice_indices, :])
  428. logits = logits * self.logit_scale # main diff from Llama
  429. loss = None
  430. if labels is not None:
  431. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  432. return CausalLMOutputWithPast(
  433. loss=loss,
  434. logits=logits,
  435. past_key_values=outputs.past_key_values,
  436. hidden_states=outputs.hidden_states,
  437. attentions=outputs.attentions,
  438. )
  439. __all__ = ["Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]