modular_bitnet.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # coding=utf-8
  2. # Copyright 2025 The BitNet Team and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and limitations under the License.
  14. """PyTorch BitNet model."""
  15. from typing import Callable, Optional
  16. import torch
  17. from ...cache_utils import Cache
  18. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  19. from ...modeling_outputs import CausalLMOutputWithPast
  20. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  21. from ...processing_utils import Unpack
  22. from ...utils import logging
  23. from ...utils.deprecation import deprecate_kwarg
  24. from ..gemma.modeling_gemma import GemmaMLP
  25. from ..llama.modeling_llama import (
  26. LlamaAttention,
  27. LlamaDecoderLayer,
  28. LlamaForCausalLM,
  29. LlamaModel,
  30. LlamaRMSNorm,
  31. apply_rotary_pos_emb,
  32. eager_attention_forward,
  33. )
  34. from .configuration_bitnet import BitNetConfig
# Module-level logger from the package's `logging` utilities, named after this module.
logger = logging.get_logger(__name__)
# RMSNorm for BitNet: inherits everything from `LlamaRMSNorm` unchanged.
# Constructed as `BitNetRMSNorm(size, eps=config.rms_norm_eps)` for the extra
# "sub-norm" layers used by BitNetMLP and BitNetAttention in this file.
class BitNetRMSNorm(LlamaRMSNorm):
    pass
  38. class BitNetMLP(GemmaMLP):
  39. def __init__(self, config: BitNetConfig):
  40. super().__init__(config)
  41. self.ffn_sub_norm = BitNetRMSNorm(config.intermediate_size, eps=config.rms_norm_eps)
  42. def forward(self, x):
  43. down_proj = self.down_proj(self.ffn_sub_norm(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))
  44. return down_proj
  45. class BitNetAttention(LlamaAttention):
  46. def __init__(self, config: BitNetConfig, layer_idx: int):
  47. super().__init__(config, layer_idx)
  48. self.attn_sub_norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  49. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  50. def forward(
  51. self,
  52. hidden_states: torch.Tensor,
  53. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  54. attention_mask: Optional[torch.Tensor],
  55. past_key_values: Optional[Cache] = None,
  56. cache_position: Optional[torch.LongTensor] = None,
  57. **kwargs: Unpack[FlashAttentionKwargs],
  58. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  59. input_shape = hidden_states.shape[:-1]
  60. hidden_shape = (*input_shape, -1, self.head_dim)
  61. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  62. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  63. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  64. cos, sin = position_embeddings
  65. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  66. if past_key_values is not None:
  67. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  68. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  69. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  70. attention_interface: Callable = eager_attention_forward
  71. if self.config._attn_implementation != "eager":
  72. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  73. attn_output, attn_weights = attention_interface(
  74. self,
  75. query_states,
  76. key_states,
  77. value_states,
  78. attention_mask,
  79. dropout=0.0 if not self.training else self.attention_dropout,
  80. scaling=self.scaling,
  81. **kwargs,
  82. )
  83. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  84. attn_output = self.attn_sub_norm(attn_output) # diff with Llama
  85. attn_output = self.o_proj(attn_output)
  86. return attn_output, attn_weights
# Decoder layer inherited unchanged from `LlamaDecoderLayer`.
# NOTE(review): this is a modular source file; presumably the modular converter
# regenerates this class with the BitNet components defined above (BitNetAttention,
# BitNetMLP, BitNetRMSNorm) substituted for the Llama ones — confirm against the
# generated modeling file.
class BitNetDecoderLayer(LlamaDecoderLayer):
    pass
# Base BitNet transformer inherited unchanged from `LlamaModel`.
# NOTE(review): presumably expanded by the modular converter into a full class that
# uses BitNetDecoderLayer / BitNetRMSNorm — confirm against the generated modeling file.
class BitNetModel(LlamaModel):
    pass
  91. class BitNetForCausalLM(LlamaForCausalLM):
  92. _tied_weights_keys = ["lm_head.weight"]
  93. _tp_plan = None
  94. _pp_plan = None
  95. def forward(
  96. self,
  97. **super_kwargs,
  98. ) -> CausalLMOutputWithPast:
  99. r"""
  100. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  101. Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,
  102. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  103. (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.
  104. Example:
  105. ```python
  106. >>> from transformers import AutoTokenizer, BitNetForCausalLM
  107. >>> model = BitNetForCausalLM.from_pretrained("microsoft/bitnet-b1.58-2B-4T")
  108. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/bitnet-b1.58-2B-4T")
  109. >>> prompt = f'<|begin_of_text|>User: Hey, are you conscious? Can you talk to me?<|eot_id|>Assistant: '
  110. >>> inputs = tokenizer(prompt, return_tensors="pt")
  111. >>> # Generate
  112. >>> generate_ids = model.generate(inputs.input_ids, max_length=100)
  113. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  114. "User: Hey, are you conscious? Can you talk to me?Assistant: No, I'm not conscious. I'm an artificial intelligence designed to assist with information and tasks. How can I help you today?"
  115. ```"""
  116. return super().forward(**super_kwargs)
# Public names exported by this module.
# `BitNetPreTrainedModel` is not defined in the visible source, hence `noqa: F822`
# ("undefined name in __all__"). NOTE(review): presumably the modular converter
# generates it from the Llama pre-trained base class — confirm.
__all__ = [
    "BitNetForCausalLM",
    "BitNetModel",
    "BitNetPreTrainedModel",  # noqa: F822
]