modular_cohere2.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. # coding=utf-8
  2. # Copyright 2024 Cohere Inc. HuggingFace Inc. team. All rights reserved.
  3. #
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from typing import Callable, Optional
  17. import torch
  18. import torch.nn as nn
  19. from ...cache_utils import Cache, DynamicCache
  20. from ...configuration_utils import PretrainedConfig, layer_type_validation
  21. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  22. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  23. from ...modeling_outputs import BaseModelOutputWithPast
  24. from ...modeling_rope_utils import rope_config_validation
  25. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  26. from ...processing_utils import Unpack
  27. from ...utils import TransformersKwargs, logging
  28. from ...utils.deprecation import deprecate_kwarg
  29. from ..cohere.modeling_cohere import (
  30. CohereAttention,
  31. CohereDecoderLayer,
  32. CohereForCausalLM,
  33. CohereLayerNorm,
  34. CoherePreTrainedModel,
  35. CohereRotaryEmbedding,
  36. apply_rotary_pos_emb,
  37. eager_attention_forward,
  38. )
  39. from ..gemma2.modeling_gemma2 import Gemma2Model
# Module-level logger, named after this module (via the library's `utils.logging.get_logger`).
logger = logging.get_logger(__name__)
  41. class Cohere2Config(PretrainedConfig):
  42. r"""
  43. This is the configuration class to store the configuration of a [`CohereModel`]. It is used to instantiate an Cohere
  44. model according to the specified arguments, defining the model architecture.
  45. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  46. documentation from [`PretrainedConfig`] for more information. Instantiating a configuration
  47. with the defaults will yield a similar configuration to that of the [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) model.
  48. Args:
  49. vocab_size (`int`, *optional*, defaults to 256000):
  50. Vocabulary size of the Cohere model. Defines the number of different tokens that can be represented by the
  51. `inputs_ids` passed when calling [`CohereModel`]
  52. hidden_size (`int`, *optional*, defaults to 8192):
  53. Dimension of the hidden representations.
  54. intermediate_size (`int`, *optional*, defaults to 22528):
  55. Dimension of the MLP representations.
  56. logit_scale (`float`, *optional*, defaults to 0.0625):
  57. The scaling factor for the output logits.
  58. num_hidden_layers (`int`, *optional*, defaults to 40):
  59. Number of hidden layers in the Transformer decoder.
  60. num_attention_heads (`int`, *optional*, defaults to 64):
  61. Number of attention heads for each attention layer in the Transformer decoder.
  62. num_key_value_heads (`int`, *optional*):
  63. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  64. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  65. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  66. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  67. by meanpooling all the original heads within that group. For more details, check out [this
  68. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
  69. `num_attention_heads`.
  70. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  71. The non-linear activation function (function or string) in the decoder.
  72. max_position_embeddings (`int`, *optional*, defaults to 8192):
  73. The maximum sequence length that this model might ever be used with.
  74. initializer_range (`float`, *optional*, defaults to 0.02):
  75. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  76. layer_norm_eps (`float`, *optional*, defaults to 1e-05):
  77. The epsilon used by the layer normalization.
  78. use_cache (`bool`, *optional*, defaults to `True`):
  79. Whether or not the model should return the last key/values attentions (not used by all models). Only
  80. relevant if `config.is_decoder=True`.
  81. pad_token_id (`int`, *optional*, defaults to 0):
  82. Padding token id.
  83. bos_token_id (`int`, *optional*, defaults to 5):
  84. Beginning of stream token id.
  85. eos_token_id (`int`, *optional*, defaults to 255001):
  86. End of stream token id.
  87. tie_word_embeddings (`bool`, *optional*, defaults to `True`):
  88. Whether to tie weight embeddings
  89. rope_theta (`float`, *optional*, defaults to 10000.0):
  90. The base period of the RoPE embeddings.
  91. rope_scaling (`Dict`, *optional*):
  92. Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
  93. and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
  94. accordingly.
  95. Expected contents:
  96. `rope_type` (`str`):
  97. The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
  98. 'llama3'], with 'default' being the original RoPE implementation.
  99. `factor` (`float`, *optional*):
  100. Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
  101. most scaling types, a `factor` of x will enable the model to handle sequences of length x *
  102. original maximum pre-trained length.
  103. `original_max_position_embeddings` (`int`, *optional*):
  104. Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
  105. pretraining.
  106. `attention_factor` (`float`, *optional*):
  107. Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
  108. computation. If unspecified, it defaults to value recommended by the implementation, using the
  109. `factor` field to infer the suggested value.
  110. `beta_fast` (`float`, *optional*):
  111. Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
  112. ramp function. If unspecified, it defaults to 32.
  113. `beta_slow` (`float`, *optional*):
  114. Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
  115. ramp function. If unspecified, it defaults to 1.
  116. `short_factor` (`list[float]`, *optional*):
  117. Only used with 'longrope'. The scaling factor to be applied to short contexts (<
  118. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  119. size divided by the number of attention heads divided by 2
  120. `long_factor` (`list[float]`, *optional*):
  121. Only used with 'longrope'. The scaling factor to be applied to long contexts (<
  122. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  123. size divided by the number of attention heads divided by 2
  124. `low_freq_factor` (`float`, *optional*):
  125. Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
  126. `high_freq_factor` (`float`, *optional*):
  127. Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
  128. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
  129. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  130. attention_dropout (`float`, *optional*, defaults to 0.0):
  131. The dropout ratio for the attention probabilities.
  132. sliding_window (`int`, *optional*, defaults to 4096):
  133. Size of the sliding window attention context.
  134. layer_types (`list`, *optional*):
  135. Attention pattern for each layer.
  136. ```python
  137. >>> from transformers import Cohere2Model, Cohere2Config
  138. >>> # Initializing a Cohere Nextmodel configuration
  139. >>> configuration = Cohere2Config()
  140. >>> # Initializing a model from the Cohere2 configuration
  141. >>> model = Cohere2Model(configuration) # doctest: +SKIP
  142. >>> # Accessing the model configuration
  143. >>> configuration = model.config # doctest: +SKIP
  144. ```
  145. """
  146. model_type = "cohere2"
  147. keys_to_ignore_at_inference = ["past_key_values"]
  148. base_model_tp_plan = {
  149. "layers.*.self_attn.q_proj": "colwise",
  150. "layers.*.self_attn.k_proj": "colwise",
  151. "layers.*.self_attn.v_proj": "colwise",
  152. "layers.*.self_attn.o_proj": "rowwise",
  153. "layers.*.mlp.gate_proj": "colwise",
  154. "layers.*.mlp.up_proj": "colwise",
  155. "layers.*.mlp.down_proj": "rowwise",
  156. }
  157. base_model_pp_plan = {
  158. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  159. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  160. "norm": (["hidden_states"], ["hidden_states"]),
  161. }
  162. def __init__(
  163. self,
  164. vocab_size=256000,
  165. hidden_size=8192,
  166. intermediate_size=22528,
  167. logit_scale=0.0625,
  168. num_hidden_layers=40,
  169. num_attention_heads=64,
  170. num_key_value_heads=None,
  171. hidden_act="silu",
  172. max_position_embeddings=8192,
  173. initializer_range=0.02,
  174. layer_norm_eps=1e-5,
  175. use_cache=True,
  176. pad_token_id=0,
  177. bos_token_id=5,
  178. eos_token_id=255001,
  179. tie_word_embeddings=True,
  180. rope_theta=10000.0,
  181. rope_scaling=None,
  182. attention_bias=False,
  183. attention_dropout=0.0,
  184. sliding_window=4096,
  185. layer_types=None,
  186. **kwargs,
  187. ):
  188. self.vocab_size = vocab_size
  189. self.max_position_embeddings = max_position_embeddings
  190. self.hidden_size = hidden_size
  191. self.logit_scale = logit_scale
  192. self.intermediate_size = intermediate_size
  193. self.num_hidden_layers = num_hidden_layers
  194. self.num_attention_heads = num_attention_heads
  195. # for backward compatibility
  196. if num_key_value_heads is None:
  197. num_key_value_heads = num_attention_heads
  198. self.num_key_value_heads = num_key_value_heads
  199. self.hidden_act = hidden_act
  200. self.initializer_range = initializer_range
  201. self.layer_norm_eps = layer_norm_eps
  202. self.use_cache = use_cache
  203. self.rope_theta = rope_theta
  204. self.rope_scaling = rope_scaling
  205. self.attention_bias = attention_bias
  206. self.attention_dropout = attention_dropout
  207. self.sliding_window = sliding_window
  208. self.layer_types = layer_types
  209. # Need to specify head_dim in the config so it can be used in the attention forward functions
  210. self.head_dim = hidden_size // num_attention_heads
  211. # Validate the correctness of rotary position embeddings parameters
  212. rope_config_validation(self)
  213. super().__init__(
  214. pad_token_id=pad_token_id,
  215. bos_token_id=bos_token_id,
  216. eos_token_id=eos_token_id,
  217. tie_word_embeddings=tie_word_embeddings,
  218. **kwargs,
  219. )
  220. # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
  221. self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 4)
  222. if self.layer_types is None:
  223. # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
  224. self._sliding_window_pattern = getattr(self, "sliding_window_pattern", 4)
  225. self.layer_types = [
  226. "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
  227. for i in range(self.num_hidden_layers)
  228. ]
  229. layer_type_validation(self.layer_types, self.num_hidden_layers)
class Cohere2RotaryEmbedding(CohereRotaryEmbedding):
    # No overrides: exposes CohereRotaryEmbedding under the Cohere2 name with identical behavior.
    pass
class Cohere2LayerNorm(CohereLayerNorm):
    # No overrides: exposes CohereLayerNorm under the Cohere2 name with identical behavior.
    pass
  234. class Cohere2Attention(CohereAttention):
  235. """Multi-headed attention from 'Attention Is All You Need' paper"""
  236. def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None):
  237. nn.Module.__init__(self)
  238. self.config = config
  239. self.layer_idx = layer_idx
  240. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  241. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  242. self.scaling = self.head_dim**-0.5
  243. self.attention_dropout = config.attention_dropout
  244. self.is_causal = True
  245. self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
  246. self.q_proj = nn.Linear(
  247. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  248. )
  249. self.k_proj = nn.Linear(
  250. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  251. )
  252. self.v_proj = nn.Linear(
  253. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  254. )
  255. self.o_proj = nn.Linear(
  256. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  257. )
  258. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  259. def forward(
  260. self,
  261. hidden_states: torch.Tensor,
  262. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  263. attention_mask: Optional[torch.Tensor],
  264. past_key_values: Optional[Cache] = None,
  265. cache_position: Optional[torch.LongTensor] = None,
  266. **kwargs: Unpack[FlashAttentionKwargs],
  267. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  268. input_shape = hidden_states.shape[:-1]
  269. hidden_shape = (*input_shape, -1, self.head_dim)
  270. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  271. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  272. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  273. cos, sin = position_embeddings
  274. if self.sliding_window is not None:
  275. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  276. if past_key_values is not None:
  277. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  278. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  279. attention_interface: Callable = eager_attention_forward
  280. if self.config._attn_implementation != "eager":
  281. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  282. attn_output, attn_weights = attention_interface(
  283. self,
  284. query_states,
  285. key_states,
  286. value_states,
  287. attention_mask,
  288. dropout=0.0 if not self.training else self.attention_dropout,
  289. scaling=self.scaling,
  290. sliding_window=self.sliding_window,
  291. **kwargs,
  292. )
  293. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  294. attn_output = self.o_proj(attn_output)
  295. return attn_output, attn_weights
  296. class Cohere2DecoderLayer(CohereDecoderLayer):
  297. def __init__(self, config: Cohere2Config, layer_idx: int):
  298. super().__init__(config, layer_idx)
  299. self.attention_type = config.layer_types[layer_idx]
  300. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  301. def forward(
  302. self,
  303. hidden_states: torch.Tensor,
  304. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  305. attention_mask: Optional[torch.Tensor] = None,
  306. past_key_values: Optional[Cache] = None,
  307. use_cache: Optional[bool] = False,
  308. cache_position: Optional[torch.LongTensor] = None,
  309. **kwargs: Unpack[FlashAttentionKwargs],
  310. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  311. residual = hidden_states
  312. hidden_states = self.input_layernorm(hidden_states)
  313. hidden_states_attention, _ = self.self_attn(
  314. hidden_states=hidden_states,
  315. position_embeddings=position_embeddings,
  316. attention_mask=attention_mask,
  317. past_key_values=past_key_values,
  318. use_cache=use_cache,
  319. cache_position=cache_position,
  320. **kwargs,
  321. )
  322. hidden_states_mlp = self.mlp(hidden_states)
  323. hidden_states = residual + hidden_states_attention + hidden_states_mlp
  324. return hidden_states
class Cohere2PreTrainedModel(CoherePreTrainedModel):
    # Only narrows the `config` attribute's declared type; all behavior comes from CoherePreTrainedModel.
    config: Cohere2Config
  327. class Cohere2Model(Gemma2Model):
  328. def __init__(self, config: Cohere2Config):
  329. super().__init__(config)
  330. self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
  331. self.rotary_emb = Cohere2RotaryEmbedding(config=config)
  332. def forward(
  333. self,
  334. input_ids: Optional[torch.LongTensor] = None,
  335. attention_mask: Optional[torch.Tensor] = None,
  336. position_ids: Optional[torch.LongTensor] = None,
  337. past_key_values: Optional[Cache] = None,
  338. inputs_embeds: Optional[torch.FloatTensor] = None,
  339. use_cache: Optional[bool] = None,
  340. cache_position: Optional[torch.LongTensor] = None,
  341. **kwargs: Unpack[TransformersKwargs],
  342. ) -> BaseModelOutputWithPast:
  343. if (input_ids is None) ^ (inputs_embeds is not None):
  344. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  345. if inputs_embeds is None:
  346. inputs_embeds = self.embed_tokens(input_ids)
  347. if use_cache and past_key_values is None and not self.training:
  348. past_key_values = DynamicCache(config=self.config)
  349. if cache_position is None:
  350. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  351. cache_position = torch.arange(
  352. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  353. )
  354. if position_ids is None:
  355. position_ids = cache_position.unsqueeze(0)
  356. if not isinstance(causal_mask_mapping := attention_mask, dict):
  357. mask_kwargs = {
  358. "config": self.config,
  359. "input_embeds": inputs_embeds,
  360. "attention_mask": attention_mask,
  361. "cache_position": cache_position,
  362. "past_key_values": past_key_values,
  363. "position_ids": position_ids,
  364. }
  365. causal_mask_mapping = {
  366. "full_attention": create_causal_mask(**mask_kwargs),
  367. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  368. }
  369. hidden_states = inputs_embeds
  370. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  371. for decoder_layer in self.layers:
  372. hidden_states = decoder_layer(
  373. hidden_states,
  374. position_embeddings=position_embeddings,
  375. attention_mask=causal_mask_mapping[decoder_layer.attention_type],
  376. past_key_values=past_key_values,
  377. use_cache=use_cache,
  378. cache_position=cache_position,
  379. **kwargs,
  380. )
  381. hidden_states = self.norm(hidden_states)
  382. return BaseModelOutputWithPast(
  383. last_hidden_state=hidden_states,
  384. past_key_values=past_key_values,
  385. )
class Cohere2ForCausalLM(CohereForCausalLM):
    # No overrides: exposes CohereForCausalLM under the Cohere2 name with identical behavior.
    pass
# Public names exported by this module.
__all__ = ["Cohere2Config", "Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]