modular_granite.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # coding=utf-8
  2. # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from typing import Optional, Union
  17. import torch
  18. from torch import nn
  19. from ...cache_utils import Cache, DynamicCache
  20. from ...masking_utils import create_causal_mask
  21. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  22. from ...processing_utils import Unpack
  23. from ...utils import TransformersKwargs, logging
  24. from ...utils.deprecation import deprecate_kwarg
  25. from ..llama.modeling_llama import (
  26. LlamaAttention,
  27. LlamaDecoderLayer,
  28. LlamaForCausalLM,
  29. LlamaModel,
  30. LlamaPreTrainedModel,
  31. )
  32. from .configuration_granite import GraniteConfig
# Module-level logger obtained from the library's `logging` utility and keyed on
# this module's name; `GraniteModel.forward` uses it to emit a one-time warning.
logger = logging.get_logger(__name__)
  34. class GraniteAttention(LlamaAttention):
  35. """Multi-headed attention from 'Attention Is All You Need' paper"""
  36. def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None):
  37. super().__init__(config, layer_idx)
  38. self.scaling = config.attention_multiplier
  39. class GraniteDecoderLayer(LlamaDecoderLayer):
  40. def __init__(self, config: GraniteConfig, layer_idx: int):
  41. super().__init__(config, layer_idx)
  42. self.residual_multiplier = config.residual_multiplier
  43. self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx)
  44. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  45. def forward(
  46. self,
  47. hidden_states: torch.Tensor,
  48. attention_mask: Optional[torch.Tensor] = None,
  49. position_ids: Optional[torch.LongTensor] = None,
  50. past_key_values: Optional[Cache] = None,
  51. output_attentions: Optional[bool] = False,
  52. use_cache: Optional[bool] = False,
  53. cache_position: Optional[torch.LongTensor] = None,
  54. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  55. **kwargs,
  56. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  57. """
  58. Args:
  59. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  60. attention_mask (`torch.FloatTensor`, *optional*):
  61. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  62. query_sequence_length, key_sequence_length)` if default attention is used.
  63. output_attentions (`bool`, *optional*):
  64. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  65. returned tensors for more detail.
  66. use_cache (`bool`, *optional*):
  67. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  68. (see `past_key_values`).
  69. past_key_values (`Cache`, *optional*): cached past key and value projection states
  70. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  71. Indices depicting the position of the input sequence tokens in the sequence
  72. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  73. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  74. with `head_dim` being the embedding dimension of each attention head.
  75. kwargs (`dict`, *optional*):
  76. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  77. into the model
  78. """
  79. residual = hidden_states
  80. hidden_states = self.input_layernorm(hidden_states)
  81. # Self Attention
  82. hidden_states, self_attn_weights = self.self_attn(
  83. hidden_states=hidden_states,
  84. attention_mask=attention_mask,
  85. position_ids=position_ids,
  86. past_key_values=past_key_values,
  87. output_attentions=output_attentions,
  88. use_cache=use_cache,
  89. cache_position=cache_position,
  90. position_embeddings=position_embeddings,
  91. **kwargs,
  92. )
  93. hidden_states = residual + hidden_states * self.residual_multiplier
  94. # Fully Connected
  95. residual = hidden_states
  96. hidden_states = self.post_attention_layernorm(hidden_states)
  97. hidden_states = self.mlp(hidden_states)
  98. hidden_states = residual + hidden_states * self.residual_multiplier # main diff with Llama
  99. outputs = (hidden_states,)
  100. if output_attentions:
  101. outputs += (self_attn_weights,)
  102. return outputs
# Granite counterpart of `LlamaPreTrainedModel`. The body is intentionally empty:
# nothing is overridden here, so all behavior comes from the Llama base class.
class GranitePreTrainedModel(LlamaPreTrainedModel):
    pass
  105. class GraniteModel(LlamaModel):
  106. def __init__(self, config: GraniteConfig):
  107. super().__init__(config)
  108. self.embedding_multiplier = config.embedding_multiplier
  109. self.layers = nn.ModuleList(
  110. [GraniteDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  111. )
  112. def forward(
  113. self,
  114. input_ids: Optional[torch.LongTensor] = None,
  115. attention_mask: Optional[torch.Tensor] = None,
  116. position_ids: Optional[torch.LongTensor] = None,
  117. past_key_values: Optional[Cache] = None,
  118. inputs_embeds: Optional[torch.FloatTensor] = None,
  119. use_cache: Optional[bool] = None,
  120. output_attentions: Optional[bool] = None,
  121. output_hidden_states: Optional[bool] = None,
  122. cache_position: Optional[torch.LongTensor] = None,
  123. **kwargs: Unpack[TransformersKwargs],
  124. ) -> BaseModelOutputWithPast:
  125. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  126. output_hidden_states = (
  127. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  128. )
  129. use_cache = use_cache if use_cache is not None else self.config.use_cache
  130. if (input_ids is None) ^ (inputs_embeds is not None):
  131. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  132. if self.gradient_checkpointing and self.training and use_cache:
  133. logger.warning_once(
  134. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  135. )
  136. use_cache = False
  137. if inputs_embeds is None:
  138. inputs_embeds = self.embed_tokens(input_ids)
  139. inputs_embeds = inputs_embeds * self.embedding_multiplier # main diff with Llama
  140. if use_cache and past_key_values is None:
  141. past_key_values = DynamicCache(config=self.config)
  142. if cache_position is None:
  143. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  144. cache_position = torch.arange(
  145. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  146. )
  147. if position_ids is None:
  148. position_ids = cache_position.unsqueeze(0)
  149. causal_mask = create_causal_mask(
  150. config=self.config,
  151. input_embeds=inputs_embeds,
  152. attention_mask=attention_mask,
  153. cache_position=cache_position,
  154. past_key_values=past_key_values,
  155. position_ids=position_ids,
  156. )
  157. hidden_states = inputs_embeds
  158. # create position embeddings to be shared across the decoder layers
  159. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  160. # decoder layers
  161. all_hidden_states = () if output_hidden_states else None
  162. all_self_attns = () if output_attentions else None
  163. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  164. if output_hidden_states:
  165. all_hidden_states += (hidden_states,)
  166. layer_outputs = decoder_layer(
  167. hidden_states,
  168. attention_mask=causal_mask,
  169. position_ids=position_ids,
  170. past_key_values=past_key_values,
  171. output_attentions=output_attentions,
  172. use_cache=use_cache,
  173. cache_position=cache_position,
  174. position_embeddings=position_embeddings,
  175. **kwargs,
  176. )
  177. hidden_states = layer_outputs[0]
  178. if output_attentions:
  179. all_self_attns += (layer_outputs[1],)
  180. hidden_states = self.norm(hidden_states)
  181. # add hidden states from the last decoder layer
  182. if output_hidden_states:
  183. all_hidden_states += (hidden_states,)
  184. return BaseModelOutputWithPast(
  185. last_hidden_state=hidden_states,
  186. past_key_values=past_key_values if use_cache else None,
  187. hidden_states=all_hidden_states,
  188. attentions=all_self_attns,
  189. )
  190. class GraniteForCausalLM(LlamaForCausalLM):
  191. def forward(
  192. self,
  193. input_ids: Optional[torch.LongTensor] = None,
  194. attention_mask: Optional[torch.Tensor] = None,
  195. position_ids: Optional[torch.LongTensor] = None,
  196. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  197. inputs_embeds: Optional[torch.FloatTensor] = None,
  198. labels: Optional[torch.LongTensor] = None,
  199. use_cache: Optional[bool] = None,
  200. output_attentions: Optional[bool] = None,
  201. output_hidden_states: Optional[bool] = None,
  202. cache_position: Optional[torch.LongTensor] = None,
  203. logits_to_keep: Union[int, torch.Tensor] = 0,
  204. **kwargs: Unpack[TransformersKwargs],
  205. ) -> CausalLMOutputWithPast:
  206. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  207. output_hidden_states = (
  208. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  209. )
  210. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  211. outputs: BaseModelOutputWithPast = self.model(
  212. input_ids=input_ids,
  213. attention_mask=attention_mask,
  214. position_ids=position_ids,
  215. past_key_values=past_key_values,
  216. inputs_embeds=inputs_embeds,
  217. use_cache=use_cache,
  218. output_attentions=output_attentions,
  219. output_hidden_states=output_hidden_states,
  220. cache_position=cache_position,
  221. **kwargs,
  222. )
  223. hidden_states = outputs.last_hidden_state
  224. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  225. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  226. logits = self.lm_head(hidden_states[:, slice_indices, :])
  227. logits = logits / self.config.logits_scaling # main diff with Llama
  228. loss = None
  229. if labels is not None:
  230. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  231. return CausalLMOutputWithPast(
  232. loss=loss,
  233. logits=logits,
  234. past_key_values=outputs.past_key_values,
  235. hidden_states=outputs.hidden_states,
  236. attentions=outputs.attentions,
  237. )
# Public names exported by this module. `GraniteAttention` and `GraniteDecoderLayer`
# are defined above but are not part of the exported list.
__all__ = ["GraniteForCausalLM", "GraniteModel", "GranitePreTrainedModel"]