backbone.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846
  1. # Copyright (c) Alibaba Cloud.
  2. #
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import importlib
  6. import math
  7. from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
  8. import torch
  9. import torch.nn.functional as F
  10. import torch.utils.checkpoint
  11. from torch import nn
  12. from torch.cuda.amp import autocast
  13. from torch.nn import CrossEntropyLoss
  14. from transformers import (GenerationConfig, PreTrainedTokenizer,
  15. StoppingCriteriaList)
  16. from transformers.generation.logits_process import LogitsProcessorList
  17. from transformers.generation.utils import GenerateOutput
  18. from transformers.modeling_outputs import (BaseModelOutputWithPast,
  19. CausalLMOutputWithPast)
  20. from transformers.modeling_utils import PreTrainedModel
  21. from transformers.trainer_utils import set_seed
  22. from transformers.utils import (ModelOutput, add_code_sample_docstrings,
  23. add_start_docstrings,
  24. add_start_docstrings_to_model_forward, logging)
  25. from transformers.utils.model_parallel_utils import (assert_device_map,
  26. get_device_map)
  27. from modelscope import Model, TorchModel
  28. from modelscope.metainfo import Models
  29. from modelscope.utils.constant import Tasks
  30. from modelscope.utils.logger import get_logger
  31. from ... import MODELS
  32. from .configuration import QWenConfig
  33. from .qwen_generation_utils import (HistoryType, StopWordsLogitsProcessor,
  34. decode_tokens, get_stop_words_ids,
  35. make_context)
  36. if TYPE_CHECKING:
  37. from transformers.generation.streamers import BaseStreamer
  38. try:
  39. from einops import rearrange
  40. except ImportError:
  41. rearrange = None
  42. try:
  43. from flash_attn.layers.rotary import apply_rotary_emb_func
  44. from einops import rearrange
  45. use_flash_rotary = True
  46. except ImportError:
  47. use_flash_rotary = False
  48. print(
  49. 'Warning: import flash_attn rotary fail, please install FlashAttention rotary to get better performance '
  50. 'https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary')
  51. try:
  52. from flash_attn.ops.rms_norm import rms_norm
  53. except ImportError:
  54. rms_norm = None
  55. print(
  56. 'Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get better performance '
  57. 'https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm'
  58. )
  59. logger = get_logger()
  60. _CHECKPOINT_FOR_DOC = 'qwen-7b'
  61. _CONFIG_FOR_DOC = 'QWenConfig'
  62. QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ['qwen-7b']
  63. try:
  64. from flash_attn.flash_attn_interface import flash_attn_unpadded_func
  65. except ImportError:
  66. flash_attn_unpadded_func = None
  67. print('Warning: import flash_attn fail, please install FlashAttention '
  68. 'https://github.com/Dao-AILab/flash-attention')
  69. class FlashSelfAttention(torch.nn.Module):
  70. def __init__(
  71. self,
  72. causal=False,
  73. softmax_scale=None,
  74. attention_dropout=0.0,
  75. ):
  76. super().__init__()
  77. assert flash_attn_unpadded_func is not None, (
  78. 'Please install FlashAttention first, '
  79. 'e.g., with pip install flash-attn')
  80. assert (rearrange is not None
  81. ), 'Please install einops first, e.g., with pip install einops'
  82. self.causal = causal
  83. self.softmax_scale = softmax_scale
  84. self.dropout_p = attention_dropout
  85. def forward(self, q, k, v):
  86. assert all(
  87. (i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
  88. assert all((i.is_cuda for i in (q, k, v)))
  89. batch_size, seqlen_q = q.shape[0], q.shape[1]
  90. seqlen_k = k.shape[1]
  91. q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
  92. cu_seqlens_q = torch.arange(
  93. 0,
  94. (batch_size + 1) * seqlen_q,
  95. step=seqlen_q,
  96. dtype=torch.int32,
  97. device=q.device,
  98. )
  99. if self.training:
  100. assert seqlen_k == seqlen_q
  101. is_causal = self.causal
  102. cu_seqlens_k = cu_seqlens_q
  103. else:
  104. is_causal = seqlen_q == seqlen_k
  105. cu_seqlens_k = torch.arange(
  106. 0,
  107. (batch_size + 1) * seqlen_k,
  108. step=seqlen_k,
  109. dtype=torch.int32,
  110. device=q.device,
  111. )
  112. self.dropout_p = 0
  113. output = flash_attn_unpadded_func(
  114. q,
  115. k,
  116. v,
  117. cu_seqlens_q,
  118. cu_seqlens_k,
  119. seqlen_q,
  120. seqlen_k,
  121. self.dropout_p,
  122. softmax_scale=self.softmax_scale,
  123. causal=is_causal,
  124. )
  125. output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
  126. return output
  127. class QWenAttention(nn.Module):
  128. def __init__(self, config, layer_number=None):
  129. super().__init__()
  130. max_positions = config.max_position_embeddings
  131. self.register_buffer(
  132. 'bias',
  133. torch.tril(
  134. torch.ones((max_positions, max_positions),
  135. dtype=torch.bool)).view(1, 1, max_positions,
  136. max_positions),
  137. persistent=False,
  138. )
  139. self.register_buffer(
  140. 'masked_bias', torch.tensor(-1e4), persistent=False)
  141. self.layer_number = max(1, layer_number)
  142. self.params_dtype = config.params_dtype
  143. self.seq_length = config.seq_length
  144. self.hidden_size = config.hidden_size
  145. self.split_size = config.hidden_size
  146. self.num_heads = config.num_attention_heads
  147. self.head_dim = self.hidden_size // self.num_heads
  148. self.use_flash_attn = config.use_flash_attn
  149. self.scale_attn_weights = True
  150. self.layer_idx = None
  151. self.projection_size = config.kv_channels * config.num_attention_heads
  152. assert self.projection_size % config.num_attention_heads == 0
  153. self.hidden_size_per_attention_head = (
  154. self.projection_size // config.num_attention_heads)
  155. self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
  156. self.c_proj = nn.Linear(
  157. config.hidden_size, self.projection_size, bias=not config.no_bias)
  158. self.is_fp32 = not (config.bf16 or config.fp16)
  159. if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
  160. self.core_attention_flash = FlashSelfAttention(
  161. causal=True, attention_dropout=config.attn_pdrop)
  162. self.bf16 = config.bf16
  163. if config.rotary_pct == 1.0:
  164. self.rotary_ndims = None
  165. else:
  166. assert config.rotary_pct < 1
  167. self.rotary_ndims = int(self.hidden_size_per_attention_head
  168. * config.rotary_pct)
  169. dim = (
  170. self.rotary_ndims if self.rotary_ndims is not None else
  171. self.hidden_size_per_attention_head)
  172. self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
  173. self.use_dynamic_ntk = config.use_dynamic_ntk
  174. self.use_logn_attn = config.use_logn_attn
  175. logn_list = [
  176. math.log(i, self.seq_length) if i > self.seq_length else 1
  177. for i in range(1, 32768)
  178. ]
  179. self.logn_tensor = torch.Tensor(logn_list)[None, :, None, None]
  180. self._ntk_cached = 1.0
  181. self.attn_dropout = nn.Dropout(config.attn_pdrop)
  182. def _attn(self, query, key, value, attention_mask=None, head_mask=None):
  183. attn_weights = torch.matmul(query, key.transpose(-1, -2))
  184. if self.scale_attn_weights:
  185. attn_weights = attn_weights / torch.full(
  186. [],
  187. value.size(-1)**0.5,
  188. dtype=attn_weights.dtype,
  189. device=attn_weights.device,
  190. )
  191. query_length, key_length = query.size(-2), key.size(-2)
  192. causal_mask = self.bias[:, :, key_length
  193. - query_length:key_length, :key_length]
  194. mask_value = torch.finfo(attn_weights.dtype).min
  195. mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
  196. attn_weights.device)
  197. attn_weights = torch.where(causal_mask,
  198. attn_weights.to(attn_weights.dtype),
  199. mask_value)
  200. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  201. attn_weights = attn_weights.type(value.dtype)
  202. attn_weights = self.attn_dropout(attn_weights)
  203. if head_mask is not None:
  204. attn_weights = attn_weights * head_mask
  205. attn_output = torch.matmul(attn_weights, value)
  206. attn_output = attn_output.transpose(1, 2)
  207. return attn_output, attn_weights
  208. def _upcast_and_reordered_attn(self,
  209. query,
  210. key,
  211. value,
  212. attention_mask=None,
  213. head_mask=None):
  214. bsz, num_heads, q_seq_len, dk = query.size()
  215. _, _, k_seq_len, _ = key.size()
  216. attn_weights = torch.empty(
  217. bsz * num_heads,
  218. q_seq_len,
  219. k_seq_len,
  220. dtype=torch.float32,
  221. device=query.device,
  222. )
  223. scale_factor = 1.0
  224. if self.scale_attn_weights:
  225. scale_factor /= float(value.size(-1))**0.5
  226. with autocast(enabled=False):
  227. q, k = query.reshape(-1, q_seq_len,
  228. dk), key.transpose(-1, -2).reshape(
  229. -1, dk, k_seq_len)
  230. attn_weights = torch.baddbmm(
  231. attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
  232. attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len,
  233. k_seq_len)
  234. query_length, key_length = query.size(-2), key.size(-2)
  235. causal_mask = self.bias[:, :, key_length
  236. - query_length:key_length, :key_length]
  237. mask_value = torch.finfo(attn_weights.dtype).min
  238. mask_value = torch.tensor(
  239. mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
  240. attn_weights = torch.where(causal_mask, attn_weights, mask_value)
  241. if attention_mask is not None:
  242. attn_weights = attn_weights + attention_mask
  243. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  244. if attn_weights.dtype != torch.float32:
  245. raise RuntimeError(
  246. 'Error with upcasting, attn_weights does not have dtype torch.float32'
  247. )
  248. attn_weights = attn_weights.type(value.dtype)
  249. attn_weights = self.attn_dropout(attn_weights)
  250. if head_mask is not None:
  251. attn_weights = attn_weights * head_mask
  252. attn_output = torch.matmul(attn_weights, value)
  253. return attn_output, attn_weights
  254. def _split_heads(self, tensor, num_heads, attn_head_size):
  255. new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
  256. tensor = tensor.view(new_shape)
  257. return tensor
  258. def _merge_heads(self, tensor, num_heads, attn_head_size):
  259. tensor = tensor.contiguous()
  260. new_shape = tensor.size()[:-2] + (num_heads * attn_head_size, )
  261. return tensor.view(new_shape)
  262. def forward(
  263. self,
  264. hidden_states: Optional[Tuple[torch.FloatTensor]],
  265. layer_past: Optional[Tuple[torch.Tensor]] = None,
  266. attention_mask: Optional[torch.FloatTensor] = None,
  267. head_mask: Optional[torch.FloatTensor] = None,
  268. encoder_hidden_states: Optional[torch.Tensor] = None,
  269. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  270. output_attentions: Optional[bool] = False,
  271. use_cache: Optional[bool] = False,
  272. ):
  273. mixed_x_layer = self.c_attn(hidden_states)
  274. query, key, value = mixed_x_layer.split(self.split_size, dim=2)
  275. query = self._split_heads(query, self.num_heads, self.head_dim)
  276. key = self._split_heads(key, self.num_heads, self.head_dim)
  277. value = self._split_heads(value, self.num_heads, self.head_dim)
  278. kv_seq_len = hidden_states.size()[1]
  279. if layer_past:
  280. kv_seq_len += layer_past[0].shape[1]
  281. if (self.use_dynamic_ntk and kv_seq_len == hidden_states.size()[1]
  282. and not self.training):
  283. context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
  284. ntk_alpha = 2**math.ceil(context_value) - 1
  285. ntk_alpha = max(ntk_alpha, 1)
  286. self._ntk_cached = ntk_alpha
  287. else:
  288. ntk_alpha = self._ntk_cached
  289. rotary_pos_emb = self.rotary_emb(
  290. kv_seq_len, ntk_alpha=ntk_alpha).to(hidden_states.device)
  291. if rotary_pos_emb is not None:
  292. if isinstance(rotary_pos_emb, tuple):
  293. rotary_pos_emb = rotary_pos_emb
  294. else:
  295. rotary_pos_emb = (rotary_pos_emb, ) * 2
  296. if rotary_pos_emb is not None:
  297. q_pos_emb, k_pos_emb = rotary_pos_emb
  298. cur_len = query.shape[1]
  299. q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
  300. k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
  301. query = apply_rotary_pos_emb(query, q_pos_emb)
  302. key = apply_rotary_pos_emb(key, k_pos_emb)
  303. if layer_past is not None:
  304. past_key, past_value = layer_past[0], layer_past[1]
  305. key = torch.cat((past_key, key), dim=1)
  306. value = torch.cat((past_value, value), dim=1)
  307. if use_cache:
  308. present = (key, value)
  309. else:
  310. present = None
  311. if self.use_logn_attn and not self.training:
  312. if self.logn_tensor.device != query.device:
  313. self.logn_tensor = self.logn_tensor.to(
  314. query.device).type_as(query)
  315. seq_start = key.size(1) - query.size(1)
  316. seq_end = key.size(1)
  317. logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
  318. query = query * logn_tensor.expand_as(query)
  319. if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32 and query.is_cuda:
  320. q, k, v = query, key, value
  321. context_layer = self.core_attention_flash(q, k, v)
  322. context_layer = rearrange(context_layer,
  323. 'b s h d -> b s (h d)').contiguous()
  324. else:
  325. query = query.permute(0, 2, 1, 3)
  326. key = key.permute(0, 2, 1, 3)
  327. value = value.permute(0, 2, 1, 3)
  328. attn_output, attn_weight = self._attn(query, key, value,
  329. attention_mask, head_mask)
  330. context_layer = self._merge_heads(attn_output, self.num_heads,
  331. self.head_dim)
  332. attn_output = self.c_proj(context_layer)
  333. outputs = (attn_output, present)
  334. if output_attentions:
  335. if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
  336. raise ValueError(
  337. 'Cannot output attentions while using flash-attn')
  338. else:
  339. outputs += (attn_weight, )
  340. return outputs
  341. class QWenMLP(nn.Module):
  342. def __init__(self, config):
  343. super().__init__()
  344. self.w1 = nn.Linear(
  345. config.hidden_size,
  346. config.ffn_hidden_size // 2,
  347. bias=not config.no_bias)
  348. self.w2 = nn.Linear(
  349. config.hidden_size,
  350. config.ffn_hidden_size // 2,
  351. bias=not config.no_bias)
  352. ff_dim_in = config.ffn_hidden_size // 2
  353. self.c_proj = nn.Linear(
  354. ff_dim_in, config.hidden_size, bias=not config.no_bias)
  355. def forward(self, hidden_states):
  356. a1 = self.w1(hidden_states)
  357. a2 = self.w2(hidden_states)
  358. intermediate_parallel = a1 * F.silu(a2)
  359. output = self.c_proj(intermediate_parallel)
  360. return output
  361. class QWenBlock(nn.Module):
  362. def __init__(self, config, layer_idx=None, num_expert=1):
  363. super().__init__()
  364. self.num_expert = num_expert
  365. self.layer_number = layer_idx
  366. self.apply_residual_connection_post_layernorm = (
  367. config.apply_residual_connection_post_layernorm)
  368. hidden_size = config.hidden_size
  369. self.apply_residual_connection_post_layernorm = (
  370. config.apply_residual_connection_post_layernorm)
  371. self.bf16 = config.bf16
  372. self.ln_1 = RMSNorm(
  373. hidden_size,
  374. eps=config.layer_norm_epsilon,
  375. )
  376. self.attn = QWenAttention(config, layer_number=layer_idx)
  377. self.ln_2 = RMSNorm(
  378. hidden_size,
  379. eps=config.layer_norm_epsilon,
  380. )
  381. self.mlp = QWenMLP(config)
  382. def forward(
  383. self,
  384. hidden_states: Optional[Tuple[torch.FloatTensor]],
  385. layer_past: Optional[Tuple[torch.Tensor]] = None,
  386. attention_mask: Optional[torch.FloatTensor] = None,
  387. head_mask: Optional[torch.FloatTensor] = None,
  388. encoder_hidden_states: Optional[torch.Tensor] = None,
  389. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  390. use_cache: Optional[bool] = False,
  391. output_attentions: Optional[bool] = False,
  392. ):
  393. layernorm_output = self.ln_1(hidden_states)
  394. attn_outputs = self.attn(
  395. layernorm_output,
  396. layer_past=layer_past,
  397. attention_mask=attention_mask,
  398. head_mask=head_mask,
  399. use_cache=use_cache,
  400. output_attentions=output_attentions,
  401. )
  402. attn_output = attn_outputs[0]
  403. outputs = attn_outputs[1:]
  404. if self.apply_residual_connection_post_layernorm:
  405. residual = layernorm_output
  406. else:
  407. residual = hidden_states
  408. layernorm_input = attn_output + residual
  409. layernorm_output = self.ln_2(layernorm_input)
  410. if self.apply_residual_connection_post_layernorm:
  411. residual = layernorm_output
  412. else:
  413. residual = layernorm_input
  414. mlp_output = self.mlp(layernorm_output)
  415. hidden_states = residual + mlp_output
  416. if use_cache:
  417. outputs = (hidden_states, ) + outputs
  418. else:
  419. outputs = (hidden_states, ) + outputs[1:]
  420. return outputs
  421. class QWenPreTrainedModel(TorchModel, PreTrainedModel):
  422. config_class = QWenConfig
  423. base_model_prefix = 'transformer'
  424. is_parallelizable = False
  425. supports_gradient_checkpointing = True
  426. _no_split_modules = ['QWenBlock']
  427. def __init__(self, config, **kwargs):
  428. super().__init__(config.name_or_path, **kwargs)
  429. super(Model, self).__init__(config)
  430. def _init_weights(self, module):
  431. """Initialize the weights."""
  432. if isinstance(module, nn.Linear):
  433. module.weight.data.normal_(
  434. mean=0.0, std=self.config.initializer_range)
  435. if module.bias is not None:
  436. module.bias.data.zero_()
  437. elif isinstance(module, nn.Embedding):
  438. module.weight.data.normal_(
  439. mean=0.0, std=self.config.initializer_range)
  440. if module.padding_idx is not None:
  441. module.weight.data[module.padding_idx].zero_()
  442. elif isinstance(module, RMSNorm):
  443. module.weight.data.fill_(1.0)
  444. for name, p in module.named_parameters():
  445. if name == 'c_proj.weight':
  446. p.data.normal_(
  447. mean=0.0,
  448. std=(self.config.initializer_range
  449. / math.sqrt(2 * self.config.n_layer)),
  450. )
  451. def _set_gradient_checkpointing(self, module, value=False):
  452. if isinstance(module, QWenModel):
  453. module.gradient_checkpointing = value
  454. @classmethod
  455. def _instantiate(cls, **kwargs):
  456. model_dir = kwargs.pop('model_dir', None)
  457. if model_dir is None:
  458. config = QWenConfig(**kwargs)
  459. model = cls(config)
  460. else:
  461. model = super(Model, cls).from_pretrained(
  462. pretrained_model_name_or_path=model_dir, **kwargs)
  463. model.model_dir = model_dir
  464. return model
  465. @MODELS.register_module(Tasks.backbone, module_name=Models.qwen_7b)
  466. class QWenModel(QWenPreTrainedModel):
  467. _keys_to_ignore_on_load_missing = ['attn.masked_bias']
  468. def __init__(self, config):
  469. super().__init__(config)
  470. self.vocab_size = config.padded_vocab_size
  471. self.num_hidden_layers = config.num_hidden_layers
  472. self.embed_dim = config.hidden_size
  473. max_sequence_length = config.max_position_embeddings
  474. self.position_embedding_type = config.pos_emb
  475. self.gradient_checkpointing = False
  476. if self.position_embedding_type == 'learned':
  477. self.wpe = nn.Embedding(max_sequence_length, self.embed_dim)
  478. self.init_method(self.position_embeddings.weight)
  479. self._position_embeddings_key = 'position_embeddings'
  480. self.init_method(self.position_embeddings.weight)
  481. else:
  482. self.wpe = None
  483. self._position_embeddings_key = ''
  484. self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
  485. self.drop = nn.Dropout(config.embd_pdrop)
  486. self.h = nn.ModuleList([
  487. QWenBlock(
  488. config,
  489. layer_idx=i,
  490. ) for i in range(config.num_hidden_layers)
  491. ])
  492. self.ln_f = RMSNorm(
  493. self.embed_dim,
  494. eps=config.layer_norm_epsilon,
  495. )
  496. self.post_init()
  497. def get_input_embeddings(self):
  498. return self.wte
  499. def set_input_embeddings(self, new_embeddings):
  500. self.wte = new_embeddings
  501. def forward(
  502. self,
  503. input_ids: Optional[torch.LongTensor] = None,
  504. past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
  505. attention_mask: Optional[torch.FloatTensor] = None,
  506. token_type_ids: Optional[torch.LongTensor] = None,
  507. position_ids: Optional[torch.LongTensor] = None,
  508. head_mask: Optional[torch.FloatTensor] = None,
  509. inputs_embeds: Optional[torch.FloatTensor] = None,
  510. encoder_hidden_states: Optional[torch.Tensor] = None,
  511. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  512. use_cache: Optional[bool] = None,
  513. output_attentions: Optional[bool] = None,
  514. output_hidden_states: Optional[bool] = None,
  515. return_dict: Optional[bool] = None,
  516. ):
  517. output_attentions = (
  518. output_attentions if output_attentions is not None else
  519. self.config.output_attentions)
  520. output_hidden_states = (
  521. output_hidden_states if output_hidden_states is not None else
  522. self.config.output_hidden_states)
  523. use_cache = use_cache if use_cache is not None else self.config.use_cache
  524. return_dict = (
  525. return_dict
  526. if return_dict is not None else self.config.use_return_dict)
  527. if input_ids is not None and inputs_embeds is not None:
  528. raise ValueError(
  529. 'You cannot specify both input_ids and inputs_embeds at the same time'
  530. )
  531. elif input_ids is not None:
  532. input_shape = input_ids.size()
  533. input_ids = input_ids.view(-1, input_shape[-1])
  534. batch_size = input_ids.shape[0]
  535. elif inputs_embeds is not None:
  536. input_shape = inputs_embeds.size()[:-1]
  537. batch_size = inputs_embeds.shape[0]
  538. else:
  539. raise ValueError(
  540. 'You have to specify either input_ids or inputs_embeds')
  541. device = input_ids.device if input_ids is not None else inputs_embeds.device
  542. if token_type_ids is not None:
  543. token_type_ids = token_type_ids.view(-1, input_shape[-1])
  544. if position_ids is not None:
  545. position_ids = position_ids.view(-1, input_shape[-1])
  546. if past_key_values is None:
  547. past_length = 0
  548. past_key_values = tuple([None] * len(self.h))
  549. else:
  550. past_length = past_key_values[0][0].size(-2)
  551. if position_ids is None:
  552. position_ids = torch.arange(
  553. past_length,
  554. input_shape[-1] + past_length,
  555. dtype=torch.long,
  556. device=device,
  557. )
  558. position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
  559. if attention_mask is not None:
  560. if batch_size <= 0:
  561. raise ValueError('batch_size has to be defined and > 0')
  562. attention_mask = attention_mask.view(batch_size, -1)
  563. attention_mask = attention_mask[:, None, None, :]
  564. attention_mask = attention_mask.to(dtype=self.dtype)
  565. attention_mask = (1.0 - attention_mask) * torch.finfo(
  566. self.dtype).min
  567. encoder_attention_mask = None
  568. head_mask = self.get_head_mask(head_mask, self.config.n_layer)
  569. if inputs_embeds is None:
  570. inputs_embeds = self.wte(input_ids)
  571. hidden_states = inputs_embeds
  572. if self.wpe is not None:
  573. position_embeds = self.wpe(position_ids)
  574. hidden_states = hidden_states + position_embeds
  575. hidden_states = self.drop(hidden_states)
  576. output_shape = input_shape + (hidden_states.size(-1), )
  577. if self.gradient_checkpointing and self.training:
  578. if use_cache:
  579. logger.warning_once(
  580. '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
  581. )
  582. use_cache = False
  583. presents = () if use_cache else None
  584. all_self_attentions = () if output_attentions else None
  585. all_hidden_states = () if output_hidden_states else None
  586. for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
  587. if output_hidden_states:
  588. all_hidden_states = all_hidden_states + (hidden_states, )
  589. if self.gradient_checkpointing and self.training:
  590. def create_custom_forward(module):
  591. def custom_forward(*inputs):
  592. return module(*inputs, use_cache, output_attentions)
  593. return custom_forward
  594. outputs = torch.utils.checkpoint.checkpoint(
  595. create_custom_forward(block),
  596. hidden_states,
  597. None,
  598. attention_mask,
  599. head_mask[i],
  600. encoder_hidden_states,
  601. encoder_attention_mask,
  602. )
  603. else:
  604. outputs = block(
  605. hidden_states,
  606. layer_past=layer_past,
  607. attention_mask=attention_mask,
  608. head_mask=head_mask[i],
  609. encoder_hidden_states=encoder_hidden_states,
  610. encoder_attention_mask=encoder_attention_mask,
  611. use_cache=use_cache,
  612. output_attentions=output_attentions,
  613. )
  614. hidden_states = outputs[0]
  615. if use_cache is True:
  616. presents = presents + (
  617. outputs[2 if output_attentions else 1], )
  618. if output_attentions:
  619. all_self_attentions = all_self_attentions + (outputs[1], )
  620. hidden_states = self.ln_f(hidden_states)
  621. hidden_states = hidden_states.view(output_shape)
  622. if not return_dict:
  623. return tuple(v
  624. for v in [hidden_states, presents, all_hidden_states]
  625. if v is not None)
  626. return BaseModelOutputWithPast(
  627. last_hidden_state=hidden_states,
  628. past_key_values=presents,
  629. hidden_states=all_hidden_states,
  630. attentions=all_self_attentions,
  631. )
  632. class RotaryEmbedding(torch.nn.Module):
  633. def __init__(self, dim, base=10000):
  634. super().__init__()
  635. self.dim = dim
  636. self.base = base
  637. self.inv_freq = 1.0 / (base**(torch.arange(0, dim, 2).float() / dim))
  638. if importlib.util.find_spec('einops') is None:
  639. raise RuntimeError('einops is required for Rotary Embedding')
  640. self._rotary_pos_emb_cache = None
  641. self._seq_len_cached = 0
  642. self._ntk_alpha_cached = 1.0
  643. def update_rotary_pos_emb_cache(self,
  644. max_seq_len,
  645. offset=0,
  646. ntk_alpha=1.0):
  647. seqlen = max_seq_len + offset
  648. if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
  649. base = self.base * ntk_alpha**(self.dim / (self.dim - 2))
  650. '''
  651. self.inv_freq = 1.0 / (
  652. base**(torch.arange(
  653. 0, self.dim, 2, device=self.inv_freq.device).float()
  654. / self.dim))
  655. '''
  656. self.inv_freq = torch.arange(
  657. 0, self.dim, 2, device=self.inv_freq.device).float() / self.dim
  658. self.inv_freq = 1.0 / (base**self.inv_freq)
  659. self._seq_len_cached = seqlen
  660. self._ntk_alpha_cached = ntk_alpha
  661. seq = torch.arange(seqlen, device=self.inv_freq.device)
  662. freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
  663. emb = torch.cat((freqs, freqs), dim=-1)
  664. from einops import rearrange
  665. self._rotary_pos_emb_cache = rearrange(emb, 'n d -> 1 n 1 d')
  666. def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
  667. self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
  668. return self._rotary_pos_emb_cache[:, offset:offset + max_seq_len]
  669. def _rotate_half(x):
  670. from einops import rearrange
  671. x = rearrange(x, '... (j d) -> ... j d', j=2)
  672. x1, x2 = x.unbind(dim=-2)
  673. return torch.cat((-x2, x1), dim=-1)
  674. def apply_rotary_pos_emb(t, freqs, use_flash_rotary=False):
  675. if use_flash_rotary:
  676. t_ = t.float()
  677. freqs = freqs.squeeze(0).squeeze(1)
  678. cos = freqs[:, :freqs.shape[-1] // 2].cos()
  679. sin = freqs[:, :freqs.shape[-1] // 2].sin()
  680. output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
  681. return output
  682. else:
  683. rot_dim = freqs.shape[-1]
  684. t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
  685. t_ = t_.float()
  686. t_pass_ = t_pass_.float()
  687. t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
  688. return torch.cat((t_, t_pass_), dim=-1).type_as(t)
  689. class RMSNorm(torch.nn.Module):
  690. def __init__(self, dim: int, eps: float = 1e-6):
  691. super().__init__()
  692. self.eps = eps
  693. self.weight = nn.Parameter(torch.ones(dim))
  694. def _norm(self, x):
  695. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  696. def forward(self, x):
  697. if rms_norm is not None and x.is_cuda:
  698. return rms_norm(x, self.weight, self.eps)
  699. else:
  700. output = self._norm(x.float()).type_as(x)
  701. return output * self.weight