modeling_cohere.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/cohere/modular_cohere.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_cohere.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2024 Cohere team. All rights reserved.
  9. #
  10. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  11. # and OPT implementations in this library. It has been modified from its
  12. # original forms to accommodate minor architectural differences compared
  13. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  14. #
  15. # Licensed under the Apache License, Version 2.0 (the "License");
  16. # you may not use this file except in compliance with the License.
  17. # You may obtain a copy of the License at
  18. #
  19. # http://www.apache.org/licenses/LICENSE-2.0
  20. #
  21. # Unless required by applicable law or agreed to in writing, software
  22. # distributed under the License is distributed on an "AS IS" BASIS,
  23. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  24. # See the License for the specific language governing permissions and
  25. # limitations under the License.
  26. # This file is based on the LLama model definition file in transformers
  27. from typing import Callable, Optional, Union
  28. import torch
  29. from torch import nn
  30. from ...activations import ACT2FN
  31. from ...cache_utils import Cache, DynamicCache
  32. from ...generation import GenerationMixin
  33. from ...masking_utils import create_causal_mask
  34. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  35. from ...modeling_layers import GradientCheckpointingLayer
  36. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  41. from ...utils.deprecation import deprecate_kwarg
  42. from ...utils.generic import check_model_inputs
  43. from .configuration_cohere import CohereConfig
  44. class CohereLayerNorm(nn.Module):
  45. def __init__(self, hidden_size=None, eps=1e-5, bias=False):
  46. """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
  47. super().__init__()
  48. self.weight = nn.Parameter(torch.ones(hidden_size))
  49. self.variance_epsilon = eps
  50. def forward(self, hidden_states):
  51. input_dtype = hidden_states.dtype
  52. hidden_states = hidden_states.to(torch.float32)
  53. mean = hidden_states.mean(-1, keepdim=True)
  54. variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
  55. hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
  56. hidden_states = self.weight.to(torch.float32) * hidden_states
  57. return hidden_states.to(input_dtype)
  58. class CohereRotaryEmbedding(nn.Module):
  59. inv_freq: torch.Tensor # fix linting for `register_buffer`
  60. def __init__(self, config: CohereConfig, device=None):
  61. super().__init__()
  62. # BC: "rope_type" was originally "type"
  63. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  64. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  65. else:
  66. self.rope_type = "default"
  67. self.max_seq_len_cached = config.max_position_embeddings
  68. self.original_max_seq_len = config.max_position_embeddings
  69. self.config = config
  70. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  71. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  72. self.register_buffer("inv_freq", inv_freq, persistent=False)
  73. self.original_inv_freq = self.inv_freq
  74. @torch.no_grad()
  75. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  76. def forward(self, x, position_ids):
  77. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  78. position_ids_expanded = position_ids[:, None, :].float()
  79. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  80. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  81. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  82. emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat()
  83. cos = emb.cos() * self.attention_scaling
  84. sin = emb.sin() * self.attention_scaling
  85. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  86. class CohereMLP(nn.Module):
  87. def __init__(self, config):
  88. super().__init__()
  89. self.config = config
  90. self.hidden_size = config.hidden_size
  91. self.intermediate_size = config.intermediate_size
  92. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  93. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  94. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  95. self.act_fn = ACT2FN[config.hidden_act]
  96. def forward(self, x):
  97. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  98. return down_proj
  99. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  100. """
  101. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  102. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  103. """
  104. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  105. if n_rep == 1:
  106. return hidden_states
  107. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  108. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  109. def eager_attention_forward(
  110. module: nn.Module,
  111. query: torch.Tensor,
  112. key: torch.Tensor,
  113. value: torch.Tensor,
  114. attention_mask: Optional[torch.Tensor],
  115. scaling: float,
  116. dropout: float = 0.0,
  117. **kwargs: Unpack[TransformersKwargs],
  118. ):
  119. key_states = repeat_kv(key, module.num_key_value_groups)
  120. value_states = repeat_kv(value, module.num_key_value_groups)
  121. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  122. if attention_mask is not None:
  123. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  124. attn_weights = attn_weights + causal_mask
  125. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  126. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  127. attn_output = torch.matmul(attn_weights, value_states)
  128. attn_output = attn_output.transpose(1, 2).contiguous()
  129. return attn_output, attn_weights
  130. def rotate_half(x):
  131. # Split and rotate. Note that this function is different from e.g. Llama.
  132. x1 = x[..., ::2]
  133. x2 = x[..., 1::2]
  134. rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
  135. return rot_x
  136. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  137. """Applies Rotary Position Embedding to the query and key tensors.
  138. Args:
  139. q (`torch.Tensor`): The query tensor.
  140. k (`torch.Tensor`): The key tensor.
  141. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  142. sin (`torch.Tensor`): The sine part of the rotary embedding.
  143. position_ids (`torch.Tensor`, *optional*):
  144. Deprecated and unused.
  145. unsqueeze_dim (`int`, *optional*, defaults to 1):
  146. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  147. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  148. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  149. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  150. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  151. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  152. Returns:
  153. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  154. """
  155. dtype = q.dtype
  156. q = q.float()
  157. k = k.float()
  158. cos = cos.unsqueeze(unsqueeze_dim)
  159. sin = sin.unsqueeze(unsqueeze_dim)
  160. q_embed = (q * cos) + (rotate_half(q) * sin)
  161. k_embed = (k * cos) + (rotate_half(k) * sin)
  162. return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
  163. class CohereAttention(nn.Module):
  164. """Multi-headed attention from 'Attention Is All You Need' paper"""
  165. def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
  166. super().__init__()
  167. self.config = config
  168. self.layer_idx = layer_idx
  169. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  170. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  171. self.scaling = self.head_dim**-0.5
  172. self.attention_dropout = config.attention_dropout
  173. self.is_causal = True
  174. self.q_proj = nn.Linear(
  175. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  176. )
  177. self.k_proj = nn.Linear(
  178. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  179. )
  180. self.v_proj = nn.Linear(
  181. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  182. )
  183. self.o_proj = nn.Linear(
  184. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  185. )
  186. self.use_qk_norm = config.use_qk_norm
  187. if self.use_qk_norm:
  188. # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
  189. self.q_norm = CohereLayerNorm(
  190. hidden_size=(config.num_attention_heads, self.head_dim), eps=config.layer_norm_eps
  191. )
  192. self.k_norm = CohereLayerNorm(
  193. hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
  194. )
  195. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  196. def forward(
  197. self,
  198. hidden_states: torch.Tensor,
  199. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  200. attention_mask: Optional[torch.Tensor],
  201. past_key_values: Optional[Cache] = None,
  202. cache_position: Optional[torch.LongTensor] = None,
  203. **kwargs: Unpack[FlashAttentionKwargs],
  204. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  205. input_shape = hidden_states.shape[:-1]
  206. hidden_shape = (*input_shape, -1, self.head_dim)
  207. query_states = self.q_proj(hidden_states).view(hidden_shape)
  208. key_states = self.k_proj(hidden_states).view(hidden_shape)
  209. value_states = self.v_proj(hidden_states).view(hidden_shape)
  210. if self.use_qk_norm: # main diff from Llama
  211. query_states = self.q_norm(query_states)
  212. key_states = self.k_norm(key_states)
  213. query_states = query_states.transpose(1, 2)
  214. key_states = key_states.transpose(1, 2)
  215. value_states = value_states.transpose(1, 2)
  216. cos, sin = position_embeddings
  217. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  218. if past_key_values is not None:
  219. # sin and cos are specific to RoPE models; position_ids needed for the static cache
  220. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  221. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  222. attention_interface: Callable = eager_attention_forward
  223. if self.config._attn_implementation != "eager":
  224. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  225. attn_output, attn_weights = attention_interface(
  226. self,
  227. query_states,
  228. key_states,
  229. value_states,
  230. attention_mask,
  231. dropout=0.0 if not self.training else self.attention_dropout,
  232. scaling=self.scaling,
  233. **kwargs,
  234. )
  235. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  236. attn_output = self.o_proj(attn_output)
  237. return attn_output, attn_weights
  238. class CohereDecoderLayer(GradientCheckpointingLayer):
  239. def __init__(self, config: CohereConfig, layer_idx: int):
  240. super().__init__()
  241. self.hidden_size = config.hidden_size
  242. self.self_attn = CohereAttention(config=config, layer_idx=layer_idx)
  243. self.mlp = CohereMLP(config)
  244. self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
  245. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  246. def forward(
  247. self,
  248. hidden_states: torch.Tensor,
  249. attention_mask: Optional[torch.Tensor] = None,
  250. position_ids: Optional[torch.LongTensor] = None,
  251. past_key_values: Optional[Cache] = None,
  252. use_cache: Optional[bool] = False,
  253. cache_position: Optional[torch.LongTensor] = None,
  254. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  255. **kwargs: Unpack[FlashAttentionKwargs],
  256. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  257. """
  258. Args:
  259. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  260. attention_mask (`torch.FloatTensor`, *optional*):
  261. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  262. query_sequence_length, key_sequence_length)` if default attention is used.
  263. past_key_values (`Cache`, *optional*): cached past key and value projection states
  264. output_attentions (`bool`, *optional*):
  265. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  266. returned tensors for more detail.
  267. use_cache (`bool`, *optional*):
  268. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  269. (see `past_key_values`).
  270. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  271. Indices depicting the position of the input sequence tokens in the sequence
  272. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  273. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  274. with `head_dim` being the embedding dimension of each attention head.
  275. """
  276. residual = hidden_states
  277. hidden_states = self.input_layernorm(hidden_states)
  278. hidden_states_attention, _ = self.self_attn(
  279. hidden_states=hidden_states,
  280. attention_mask=attention_mask,
  281. position_ids=position_ids,
  282. past_key_values=past_key_values,
  283. use_cache=use_cache,
  284. cache_position=cache_position,
  285. position_embeddings=position_embeddings,
  286. **kwargs,
  287. )
  288. hidden_states_mlp = self.mlp(hidden_states)
  289. hidden_states = residual + hidden_states_attention + hidden_states_mlp
  290. return hidden_states
  291. @auto_docstring
  292. class CoherePreTrainedModel(PreTrainedModel):
  293. config: CohereConfig
  294. base_model_prefix = "model"
  295. supports_gradient_checkpointing = True
  296. _no_split_modules = ["CohereDecoderLayer"]
  297. _skip_keys_device_placement = ["past_key_values"]
  298. _supports_flash_attn = True
  299. _supports_sdpa = True
  300. _supports_flex_attn = True
  301. _can_compile_fullgraph = True
  302. _supports_attention_backend = True
  303. _can_record_outputs = {
  304. "hidden_states": CohereDecoderLayer,
  305. "attentions": CohereAttention,
  306. }
  307. @auto_docstring
  308. class CohereModel(CoherePreTrainedModel):
  309. def __init__(self, config: CohereConfig):
  310. super().__init__(config)
  311. self.padding_idx = config.pad_token_id
  312. self.vocab_size = config.vocab_size
  313. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  314. self.layers = nn.ModuleList(
  315. [CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  316. )
  317. self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
  318. self.rotary_emb = CohereRotaryEmbedding(config=config)
  319. self.gradient_checkpointing = False
  320. # Initialize weights and apply final processing
  321. self.post_init()
  322. @check_model_inputs()
  323. @auto_docstring
  324. def forward(
  325. self,
  326. input_ids: Optional[torch.LongTensor] = None,
  327. attention_mask: Optional[torch.Tensor] = None,
  328. position_ids: Optional[torch.LongTensor] = None,
  329. past_key_values: Optional[Cache] = None,
  330. inputs_embeds: Optional[torch.FloatTensor] = None,
  331. cache_position: Optional[torch.LongTensor] = None,
  332. use_cache: Optional[bool] = None,
  333. **kwargs: Unpack[TransformersKwargs],
  334. ) -> BaseModelOutputWithPast:
  335. if (input_ids is None) ^ (inputs_embeds is not None):
  336. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  337. if inputs_embeds is None:
  338. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  339. if use_cache and past_key_values is None:
  340. past_key_values = DynamicCache(config=self.config)
  341. if cache_position is None:
  342. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  343. cache_position: torch.Tensor = torch.arange(
  344. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  345. )
  346. if position_ids is None:
  347. position_ids = cache_position.unsqueeze(0)
  348. causal_mask = create_causal_mask(
  349. config=self.config,
  350. input_embeds=inputs_embeds,
  351. attention_mask=attention_mask,
  352. cache_position=cache_position,
  353. past_key_values=past_key_values,
  354. position_ids=position_ids,
  355. )
  356. hidden_states = inputs_embeds
  357. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  358. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  359. hidden_states = decoder_layer(
  360. hidden_states,
  361. attention_mask=causal_mask,
  362. position_ids=position_ids,
  363. past_key_values=past_key_values,
  364. cache_position=cache_position,
  365. position_embeddings=position_embeddings,
  366. **kwargs,
  367. )
  368. hidden_states = self.norm(hidden_states)
  369. return BaseModelOutputWithPast(
  370. last_hidden_state=hidden_states,
  371. past_key_values=past_key_values,
  372. )
  373. @auto_docstring
  374. class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
  375. _tied_weights_keys = ["lm_head.weight"]
  376. _tp_plan = {"lm_head": "colwise_rep"}
  377. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  378. def __init__(self, config):
  379. super().__init__(config)
  380. self.model = CohereModel(config)
  381. self.vocab_size = config.vocab_size
  382. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  383. self.logit_scale = config.logit_scale
  384. self.tie_word_embeddings = config.tie_word_embeddings
  385. # Initialize weights and apply final processing
  386. self.post_init()
  387. @can_return_tuple
  388. @auto_docstring
  389. def forward(
  390. self,
  391. input_ids: Optional[torch.LongTensor] = None,
  392. attention_mask: Optional[torch.Tensor] = None,
  393. position_ids: Optional[torch.LongTensor] = None,
  394. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  395. inputs_embeds: Optional[torch.FloatTensor] = None,
  396. labels: Optional[torch.LongTensor] = None,
  397. use_cache: Optional[bool] = None,
  398. output_attentions: Optional[bool] = None,
  399. output_hidden_states: Optional[bool] = None,
  400. cache_position: Optional[torch.LongTensor] = None,
  401. logits_to_keep: Union[int, torch.Tensor] = 0,
  402. **kwargs: Unpack[TransformersKwargs],
  403. ) -> CausalLMOutputWithPast:
  404. r"""
  405. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  406. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  407. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  408. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  409. Example:
  410. ```python
  411. >> from transformers import AutoTokenizer, CohereForCausalLM
  412. >> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
  413. >> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
  414. >> prompt = "Hey, are you conscious? Can you talk to me?"
  415. >> inputs = tokenizer(prompt, return_tensors="pt")
  416. >> # Generate
  417. >> generate_ids = model.generate(inputs.input_ids, max_length=30)
  418. >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  419. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  420. ```"""
  421. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  422. output_hidden_states = (
  423. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  424. )
  425. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  426. outputs: BaseModelOutputWithPast = self.model(
  427. input_ids=input_ids,
  428. attention_mask=attention_mask,
  429. position_ids=position_ids,
  430. past_key_values=past_key_values,
  431. inputs_embeds=inputs_embeds,
  432. use_cache=use_cache,
  433. output_attentions=output_attentions,
  434. output_hidden_states=output_hidden_states,
  435. cache_position=cache_position,
  436. **kwargs,
  437. )
  438. hidden_states = outputs.last_hidden_state
  439. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  440. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  441. logits = self.lm_head(hidden_states[:, slice_indices, :])
  442. logits = logits * self.logit_scale # main diff from Llama
  443. loss = None
  444. if labels is not None:
  445. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  446. return CausalLMOutputWithPast(
  447. loss=loss,
  448. logits=logits,
  449. past_key_values=outputs.past_key_values,
  450. hidden_states=outputs.hidden_states,
  451. attentions=outputs.attentions,
  452. )
  453. __all__ = ["CohereForCausalLM", "CohereModel", "CoherePreTrainedModel"]