modular_cohere.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. # coding=utf-8
  2. # Copyright 2024 Cohere team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. # This file is based on the LLama model definition file in transformers
  21. """PyTorch Cohere model."""
  22. from typing import Callable, Optional, Union
  23. import torch
  24. from torch import nn
  25. from ...cache_utils import Cache
  26. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  29. from ...modeling_rope_utils import dynamic_rope_update
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  31. from ...processing_utils import Unpack
  32. from ...utils import TransformersKwargs, logging
  33. from ...utils.deprecation import deprecate_kwarg
  34. from ..llama.modeling_llama import (
  35. LlamaAttention,
  36. LlamaForCausalLM,
  37. LlamaMLP,
  38. LlamaModel,
  39. LlamaRotaryEmbedding,
  40. eager_attention_forward,
  41. )
  42. from .configuration_cohere import CohereConfig
# Module-level logger, named after this module via the library's logging helper.
logger = logging.get_logger(__name__)
  44. class CohereLayerNorm(nn.Module):
  45. def __init__(self, hidden_size=None, eps=1e-5, bias=False):
  46. """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
  47. super().__init__()
  48. self.weight = nn.Parameter(torch.ones(hidden_size))
  49. self.variance_epsilon = eps
  50. def forward(self, hidden_states):
  51. input_dtype = hidden_states.dtype
  52. hidden_states = hidden_states.to(torch.float32)
  53. mean = hidden_states.mean(-1, keepdim=True)
  54. variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
  55. hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
  56. hidden_states = self.weight.to(torch.float32) * hidden_states
  57. return hidden_states.to(input_dtype)
  58. class CohereRotaryEmbedding(LlamaRotaryEmbedding):
  59. @torch.no_grad()
  60. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  61. def forward(self, x, position_ids):
  62. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  63. position_ids_expanded = position_ids[:, None, :].float()
  64. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  65. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  66. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  67. emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat()
  68. cos = emb.cos() * self.attention_scaling
  69. sin = emb.sin() * self.attention_scaling
  70. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  71. def rotate_half(x):
  72. # Split and rotate. Note that this function is different from e.g. Llama.
  73. x1 = x[..., ::2]
  74. x2 = x[..., 1::2]
  75. rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
  76. return rot_x
  77. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  78. """Applies Rotary Position Embedding to the query and key tensors.
  79. Args:
  80. q (`torch.Tensor`): The query tensor.
  81. k (`torch.Tensor`): The key tensor.
  82. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  83. sin (`torch.Tensor`): The sine part of the rotary embedding.
  84. position_ids (`torch.Tensor`, *optional*):
  85. Deprecated and unused.
  86. unsqueeze_dim (`int`, *optional*, defaults to 1):
  87. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  88. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  89. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  90. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  91. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  92. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  93. Returns:
  94. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  95. """
  96. dtype = q.dtype
  97. q = q.float()
  98. k = k.float()
  99. cos = cos.unsqueeze(unsqueeze_dim)
  100. sin = sin.unsqueeze(unsqueeze_dim)
  101. q_embed = (q * cos) + (rotate_half(q) * sin)
  102. k_embed = (k * cos) + (rotate_half(k) * sin)
  103. return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
  104. class CohereMLP(LlamaMLP):
  105. def __init__(self, config):
  106. super().__init__(config)
  107. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  108. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  109. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  110. class CohereAttention(LlamaAttention):
  111. """Multi-headed attention from 'Attention Is All You Need' paper"""
  112. def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
  113. super().__init__(config, layer_idx)
  114. self.use_qk_norm = config.use_qk_norm
  115. if self.use_qk_norm:
  116. # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
  117. self.q_norm = CohereLayerNorm(
  118. hidden_size=(config.num_attention_heads, self.head_dim), eps=config.layer_norm_eps
  119. )
  120. self.k_norm = CohereLayerNorm(
  121. hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
  122. )
  123. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  124. def forward(
  125. self,
  126. hidden_states: torch.Tensor,
  127. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  128. attention_mask: Optional[torch.Tensor],
  129. past_key_values: Optional[Cache] = None,
  130. cache_position: Optional[torch.LongTensor] = None,
  131. **kwargs: Unpack[FlashAttentionKwargs],
  132. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  133. input_shape = hidden_states.shape[:-1]
  134. hidden_shape = (*input_shape, -1, self.head_dim)
  135. query_states = self.q_proj(hidden_states).view(hidden_shape)
  136. key_states = self.k_proj(hidden_states).view(hidden_shape)
  137. value_states = self.v_proj(hidden_states).view(hidden_shape)
  138. if self.use_qk_norm: # main diff from Llama
  139. query_states = self.q_norm(query_states)
  140. key_states = self.k_norm(key_states)
  141. query_states = query_states.transpose(1, 2)
  142. key_states = key_states.transpose(1, 2)
  143. value_states = value_states.transpose(1, 2)
  144. cos, sin = position_embeddings
  145. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  146. if past_key_values is not None:
  147. # sin and cos are specific to RoPE models; position_ids needed for the static cache
  148. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  149. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  150. attention_interface: Callable = eager_attention_forward
  151. if self.config._attn_implementation != "eager":
  152. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  153. attn_output, attn_weights = attention_interface(
  154. self,
  155. query_states,
  156. key_states,
  157. value_states,
  158. attention_mask,
  159. dropout=0.0 if not self.training else self.attention_dropout,
  160. scaling=self.scaling,
  161. **kwargs,
  162. )
  163. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  164. attn_output = self.o_proj(attn_output)
  165. return attn_output, attn_weights
  166. class CohereDecoderLayer(GradientCheckpointingLayer):
  167. def __init__(self, config: CohereConfig, layer_idx: int):
  168. super().__init__()
  169. self.hidden_size = config.hidden_size
  170. self.self_attn = CohereAttention(config=config, layer_idx=layer_idx)
  171. self.mlp = CohereMLP(config)
  172. self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
  173. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  174. def forward(
  175. self,
  176. hidden_states: torch.Tensor,
  177. attention_mask: Optional[torch.Tensor] = None,
  178. position_ids: Optional[torch.LongTensor] = None,
  179. past_key_values: Optional[Cache] = None,
  180. use_cache: Optional[bool] = False,
  181. cache_position: Optional[torch.LongTensor] = None,
  182. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  183. **kwargs: Unpack[FlashAttentionKwargs],
  184. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  185. """
  186. Args:
  187. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  188. attention_mask (`torch.FloatTensor`, *optional*):
  189. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  190. query_sequence_length, key_sequence_length)` if default attention is used.
  191. past_key_values (`Cache`, *optional*): cached past key and value projection states
  192. output_attentions (`bool`, *optional*):
  193. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  194. returned tensors for more detail.
  195. use_cache (`bool`, *optional*):
  196. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  197. (see `past_key_values`).
  198. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  199. Indices depicting the position of the input sequence tokens in the sequence
  200. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  201. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  202. with `head_dim` being the embedding dimension of each attention head.
  203. """
  204. residual = hidden_states
  205. hidden_states = self.input_layernorm(hidden_states)
  206. hidden_states_attention, _ = self.self_attn(
  207. hidden_states=hidden_states,
  208. attention_mask=attention_mask,
  209. position_ids=position_ids,
  210. past_key_values=past_key_values,
  211. use_cache=use_cache,
  212. cache_position=cache_position,
  213. position_embeddings=position_embeddings,
  214. **kwargs,
  215. )
  216. hidden_states_mlp = self.mlp(hidden_states)
  217. hidden_states = residual + hidden_states_attention + hidden_states_mlp
  218. return hidden_states
  219. class CohereModel(LlamaModel):
  220. def __init__(self, config: CohereConfig):
  221. super().__init__(config)
  222. self.layers = nn.ModuleList(
  223. [CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  224. )
  225. self.rotary_emb = CohereRotaryEmbedding(config=config)
  226. self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
  227. class CohereForCausalLM(LlamaForCausalLM):
  228. def __init__(self, config):
  229. super().__init__(config)
  230. self.model = CohereModel(config)
  231. self.logit_scale = config.logit_scale
  232. self.tie_word_embeddings = config.tie_word_embeddings
  233. def forward(
  234. self,
  235. input_ids: Optional[torch.LongTensor] = None,
  236. attention_mask: Optional[torch.Tensor] = None,
  237. position_ids: Optional[torch.LongTensor] = None,
  238. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  239. inputs_embeds: Optional[torch.FloatTensor] = None,
  240. labels: Optional[torch.LongTensor] = None,
  241. use_cache: Optional[bool] = None,
  242. output_attentions: Optional[bool] = None,
  243. output_hidden_states: Optional[bool] = None,
  244. cache_position: Optional[torch.LongTensor] = None,
  245. logits_to_keep: Union[int, torch.Tensor] = 0,
  246. **kwargs: Unpack[TransformersKwargs],
  247. ) -> CausalLMOutputWithPast:
  248. r"""
  249. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  250. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  251. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  252. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  253. Example:
  254. ```python
  255. >> from transformers import AutoTokenizer, CohereForCausalLM
  256. >> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
  257. >> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
  258. >> prompt = "Hey, are you conscious? Can you talk to me?"
  259. >> inputs = tokenizer(prompt, return_tensors="pt")
  260. >> # Generate
  261. >> generate_ids = model.generate(inputs.input_ids, max_length=30)
  262. >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  263. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  264. ```"""
  265. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  266. output_hidden_states = (
  267. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  268. )
  269. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  270. outputs: BaseModelOutputWithPast = self.model(
  271. input_ids=input_ids,
  272. attention_mask=attention_mask,
  273. position_ids=position_ids,
  274. past_key_values=past_key_values,
  275. inputs_embeds=inputs_embeds,
  276. use_cache=use_cache,
  277. output_attentions=output_attentions,
  278. output_hidden_states=output_hidden_states,
  279. cache_position=cache_position,
  280. **kwargs,
  281. )
  282. hidden_states = outputs.last_hidden_state
  283. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  284. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  285. logits = self.lm_head(hidden_states[:, slice_indices, :])
  286. logits = logits * self.logit_scale # main diff from Llama
  287. loss = None
  288. if labels is not None:
  289. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  290. return CausalLMOutputWithPast(
  291. loss=loss,
  292. logits=logits,
  293. past_key_values=outputs.past_key_values,
  294. hidden_states=outputs.hidden_states,
  295. attentions=outputs.attentions,
  296. )
# Public names exported by this module.
# `CoherePreTrainedModel` has no definition in the code visible here, hence the
# F822 suppression — presumably it is supplied when the modeling file is
# generated from this modular file; TODO confirm.
__all__ = [
    "CohereForCausalLM",
    "CohereModel",
    "CoherePreTrainedModel",  # noqa: F822
]