text_generation.py 55 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440
  1. """ PyTorch ChatGLM model. """
  2. import copy
  3. import math
  4. import sys
  5. import warnings
  6. from typing import Any, Callable, Dict, List, Optional, Tuple
  7. import torch
  8. import torch.nn.functional as F
  9. import torch.utils.checkpoint
  10. from torch import nn
  11. from torch.nn import CrossEntropyLoss, LayerNorm
  12. from torch.nn.utils import skip_init
  13. from transformers.generation.logits_process import LogitsProcessor
  14. from transformers.generation.utils import (GenerationConfig,
  15. LogitsProcessorList, ModelOutput,
  16. StoppingCriteriaList)
  17. from transformers.modeling_outputs import (BaseModelOutputWithPast,
  18. CausalLMOutputWithPast)
  19. from transformers.modeling_utils import PreTrainedModel
  20. from modelscope import Model, TorchModel
  21. from modelscope.metainfo import Models
  22. from modelscope.outputs import OutputKeys
  23. from modelscope.utils import logger as logging
  24. from modelscope.utils.constant import Tasks
  25. from ... import MODELS
  26. from .configuration import ChatGLM2Config
  27. # flags required to enable jit fusion kernels
  28. if sys.platform != 'darwin':
  29. torch._C._jit_set_profiling_mode(False)
  30. torch._C._jit_set_profiling_executor(False)
  31. torch._C._jit_override_can_fuse_on_cpu(True)
  32. torch._C._jit_override_can_fuse_on_gpu(True)
  33. logger = logging.get_logger()
  34. _CHECKPOINT_FOR_DOC = 'THUDM/ChatGLM2-6B'
  35. _CONFIG_FOR_DOC = 'ChatGLM6BConfig'
  36. CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
  37. 'THUDM/chatglm2-6b',
  38. # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
  39. ]
  40. def default_init(cls, *args, **kwargs):
  41. return cls(*args, **kwargs)
  42. class InvalidScoreLogitsProcessor(LogitsProcessor):
  43. def __call__(self, input_ids: torch.LongTensor,
  44. scores: torch.FloatTensor) -> torch.FloatTensor:
  45. if torch.isnan(scores).any() or torch.isinf(scores).any():
  46. scores.zero_()
  47. scores[..., 5] = 5e4
  48. return scores
  49. class PrefixEncoder(torch.nn.Module):
  50. """
  51. The torch.nn model to encode the prefix
  52. Input shape: (batch-size, prefix-length)
  53. Output shape: (batch-size, prefix-length, 2*layers*hidden)
  54. """
  55. def __init__(self, config: ChatGLM2Config):
  56. super().__init__()
  57. self.prefix_projection = config.prefix_projection
  58. if self.prefix_projection:
  59. # Use a two-layer MLP to encode the prefix
  60. kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
  61. self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
  62. self.trans = torch.nn.Sequential(
  63. torch.nn.Linear(kv_size, config.hidden_size), torch.nn.Tanh(),
  64. torch.nn.Linear(config.hidden_size, kv_size))
  65. else:
  66. self.embedding = torch.nn.Embedding(
  67. config.pre_seq_len, config.num_layers * config.kv_channels
  68. * config.multi_query_group_num * 2)
  69. def forward(self, prefix: torch.Tensor):
  70. if self.prefix_projection:
  71. prefix_tokens = self.embedding(prefix)
  72. past_key_values = self.trans(prefix_tokens)
  73. else:
  74. past_key_values = self.embedding(prefix)
  75. return past_key_values
  76. def split_tensor_along_last_dim(
  77. tensor: torch.Tensor,
  78. num_partitions: int,
  79. contiguous_split_chunks: bool = False,
  80. ) -> List[torch.Tensor]:
  81. """Split a tensor along its last dimension.
  82. Arguments:
  83. tensor: input tensor.
  84. num_partitions: number of partitions to split the tensor
  85. contiguous_split_chunks: If True, make each chunk contiguous
  86. in memory.
  87. Returns:
  88. A list of Tensors
  89. """
  90. # Get the size and dimension.
  91. last_dim = tensor.dim() - 1
  92. last_dim_size = tensor.size()[last_dim] // num_partitions
  93. # Split.
  94. tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
  95. # Note: torch.split does not create contiguous tensors by default.
  96. if contiguous_split_chunks:
  97. return tuple(chunk.contiguous() for chunk in tensor_list)
  98. return tensor_list
  99. class RotaryEmbedding(nn.Module):
  100. def __init__(self,
  101. dim,
  102. rope_ratio=1,
  103. original_impl=False,
  104. device=None,
  105. dtype=None):
  106. super().__init__()
  107. inv_freq = 1.0 / (10000**(
  108. torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
  109. self.register_buffer('inv_freq', inv_freq)
  110. self.dim = dim
  111. self.original_impl = original_impl
  112. self.rope_ratio = rope_ratio
  113. def forward_impl(self,
  114. seq_len: int,
  115. n_elem: int,
  116. dtype: torch.dtype,
  117. device: torch.device,
  118. base: int = 10000):
  119. """Enhanced Transformer with Rotary Position Embedding.
  120. Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
  121. transformers/rope/__init__.py. MIT License:
  122. https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
  123. """
  124. # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
  125. theta = 1.0 / (
  126. base**(torch.arange(0, n_elem, 2, dtype=dtype, device=device)
  127. / n_elem))
  128. # Create position indexes `[0, 1, ..., seq_len - 1]`
  129. seq_idx = torch.arange(
  130. seq_len, dtype=dtype, device=device) / self.rope_ratio
  131. # Calculate the product of position index and $\theta_i$
  132. idx_theta = torch.outer(seq_idx, theta).float()
  133. cache = torch.stack(
  134. [torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
  135. # this is to mimic the behaviour of complex32, else we will get different results
  136. if dtype in (torch.float16, torch.bfloat16, torch.int8):
  137. cache = cache.bfloat16(
  138. ) if dtype == torch.bfloat16 else cache.half()
  139. return cache
  140. def forward(self, max_seq_len, offset=0):
  141. return self.forward_impl(
  142. max_seq_len,
  143. self.dim,
  144. dtype=self.inv_freq.dtype,
  145. device=self.inv_freq.device)
  146. @torch.jit.script
  147. def apply_rotary_pos_emb(x: torch.Tensor,
  148. rope_cache: torch.Tensor) -> torch.Tensor:
  149. # x: [sq, b, np, hn]
  150. sq, _, np, _ = x.size(0), x.size(1), x.size(2), x.size(3)
  151. rot_dim = rope_cache.shape[-2] * 2
  152. x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
  153. # truncate to support variable sizes
  154. rope_cache = rope_cache[:sq]
  155. xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
  156. rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
  157. x_out2 = torch.stack(
  158. [
  159. xshaped[..., 0] * rope_cache[..., 0]
  160. - xshaped[..., 1] * rope_cache[..., 1],
  161. xshaped[..., 1] * rope_cache[..., 0]
  162. + xshaped[..., 0] * rope_cache[..., 1],
  163. ],
  164. -1,
  165. )
  166. x_out2 = x_out2.flatten(3)
  167. return torch.cat((x_out2, x_pass), dim=-1)
  168. class RMSNorm(torch.nn.Module):
  169. def __init__(self,
  170. normalized_shape,
  171. eps=1e-5,
  172. device=None,
  173. dtype=None,
  174. **kwargs):
  175. super().__init__()
  176. self.weight = torch.nn.Parameter(
  177. torch.empty(normalized_shape, device=device, dtype=dtype))
  178. self.eps = eps
  179. def forward(self, hidden_states: torch.Tensor):
  180. input_dtype = hidden_states.dtype
  181. variance = hidden_states.to(torch.float32).pow(2).mean(
  182. -1, keepdim=True)
  183. hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
  184. return (self.weight * hidden_states).to(input_dtype)
  185. class CoreAttention(torch.nn.Module):
  186. def __init__(self, config: ChatGLM2Config, layer_number):
  187. super(CoreAttention, self).__init__()
  188. self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
  189. self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
  190. if self.apply_query_key_layer_scaling:
  191. self.attention_softmax_in_fp32 = True
  192. self.layer_number = max(1, layer_number)
  193. projection_size = config.kv_channels * config.num_attention_heads
  194. # Per attention head and per partition values.
  195. self.hidden_size_per_partition = projection_size
  196. self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
  197. self.num_attention_heads_per_partition = config.num_attention_heads
  198. coeff = None
  199. self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
  200. if self.apply_query_key_layer_scaling:
  201. coeff = self.layer_number
  202. self.norm_factor *= coeff
  203. self.coeff = coeff
  204. self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
  205. def forward(self, query_layer, key_layer, value_layer, attention_mask):
  206. pytorch_major_version = int(torch.__version__.split('.')[0])
  207. if pytorch_major_version >= 2:
  208. query_layer, key_layer, value_layer = [
  209. k.permute(1, 2, 0, 3)
  210. for k in [query_layer, key_layer, value_layer]
  211. ]
  212. if attention_mask is None and query_layer.shape[
  213. 2] == key_layer.shape[2]:
  214. context_layer = torch.nn.functional.scaled_dot_product_attention(
  215. query_layer, key_layer, value_layer, is_causal=True)
  216. else:
  217. if attention_mask is not None:
  218. attention_mask = ~attention_mask
  219. context_layer = torch.nn.functional.scaled_dot_product_attention(
  220. query_layer, key_layer, value_layer, attention_mask)
  221. context_layer = context_layer.permute(2, 0, 1, 3)
  222. new_context_layer_shape = context_layer.size()[:-2] + (
  223. self.hidden_size_per_partition, )
  224. context_layer = context_layer.reshape(*new_context_layer_shape)
  225. else:
  226. # Raw attention scores
  227. # [b, np, sq, sk]
  228. output_size = (query_layer.size(1), query_layer.size(2),
  229. query_layer.size(0), key_layer.size(0))
  230. # [sq, b, np, hn] -> [sq, b * np, hn]
  231. query_layer = query_layer.view(output_size[2],
  232. output_size[0] * output_size[1], -1)
  233. # [sk, b, np, hn] -> [sk, b * np, hn]
  234. key_layer = key_layer.view(output_size[3],
  235. output_size[0] * output_size[1], -1)
  236. # preallocting input tensor: [b * np, sq, sk]
  237. matmul_input_buffer = torch.empty(
  238. output_size[0] * output_size[1],
  239. output_size[2],
  240. output_size[3],
  241. dtype=query_layer.dtype,
  242. device=query_layer.device)
  243. # Raw attention scores. [b * np, sq, sk]
  244. matmul_result = torch.baddbmm(
  245. matmul_input_buffer,
  246. query_layer.transpose(0, 1), # [b * np, sq, hn]
  247. key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
  248. beta=0.0,
  249. alpha=(1.0 / self.norm_factor),
  250. )
  251. # change view to [b, np, sq, sk]
  252. attention_scores = matmul_result.view(*output_size)
  253. # ===========================
  254. # Attention probs and dropout
  255. # ===========================
  256. # attention scores and attention mask [b, np, sq, sk]
  257. if self.attention_softmax_in_fp32:
  258. attention_scores = attention_scores.float()
  259. if self.coeff is not None:
  260. attention_scores = attention_scores * self.coeff
  261. if attention_mask is None and attention_scores.shape[
  262. 2] == attention_scores.shape[3]:
  263. attention_mask = torch.ones(
  264. output_size[0],
  265. 1,
  266. output_size[2],
  267. output_size[3],
  268. device=attention_scores.device,
  269. dtype=torch.bool)
  270. attention_mask.tril_()
  271. attention_mask = ~attention_mask
  272. if attention_mask is not None:
  273. attention_scores = attention_scores.masked_fill(
  274. attention_mask, float('-inf'))
  275. attention_probs = F.softmax(attention_scores, dim=-1)
  276. attention_probs = attention_probs.type_as(value_layer)
  277. # This is actually dropping out entire tokens to attend to, which might
  278. # seem a bit unusual, but is taken from the original Transformer paper.
  279. attention_probs = self.attention_dropout(attention_probs)
  280. # =========================
  281. # Context layer. [sq, b, hp]
  282. # =========================
  283. # value_layer -> context layer.
  284. # [sk, b, np, hn] --> [b, np, sq, hn]
  285. # context layer shape: [b, np, sq, hn]
  286. output_size = (value_layer.size(1), value_layer.size(2),
  287. query_layer.size(0), value_layer.size(3))
  288. # change view [sk, b * np, hn]
  289. value_layer = value_layer.view(
  290. value_layer.size(0), output_size[0] * output_size[1], -1)
  291. # change view [b * np, sq, sk]
  292. attention_probs = attention_probs.view(
  293. output_size[0] * output_size[1], output_size[2], -1)
  294. # matmul: [b * np, sq, hn]
  295. context_layer = torch.bmm(attention_probs,
  296. value_layer.transpose(0, 1))
  297. # change view [b, np, sq, hn]
  298. context_layer = context_layer.view(*output_size)
  299. # [b, np, sq, hn] --> [sq, b, np, hn]
  300. context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
  301. # [sq, b, np, hn] --> [sq, b, hp]
  302. new_context_layer_shape = context_layer.size()[:-2] + (
  303. self.hidden_size_per_partition, )
  304. context_layer = context_layer.view(*new_context_layer_shape)
  305. return context_layer
  306. class SelfAttention(torch.nn.Module):
  307. """Parallel self-attention layer abstract class.
  308. Self-attention layer takes input with size [s, b, h]
  309. and returns output of the same size.
  310. """
  311. def __init__(self, config: ChatGLM2Config, layer_number, device=None):
  312. super(SelfAttention, self).__init__()
  313. self.layer_number = max(1, layer_number)
  314. self.projection_size = config.kv_channels * config.num_attention_heads
  315. # Per attention head and per partition values.
  316. self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
  317. self.num_attention_heads_per_partition = config.num_attention_heads
  318. self.multi_query_attention = config.multi_query_attention
  319. self.qkv_hidden_size = 3 * self.projection_size
  320. if self.multi_query_attention:
  321. self.num_multi_query_groups_per_partition = config.multi_query_group_num
  322. self.qkv_hidden_size = (
  323. self.projection_size + 2 * self.hidden_size_per_attention_head
  324. * config.multi_query_group_num)
  325. self.query_key_value = nn.Linear(
  326. config.hidden_size,
  327. self.qkv_hidden_size,
  328. bias=config.add_bias_linear or config.add_qkv_bias,
  329. device=device,
  330. **_config_to_kwargs(config))
  331. self.core_attention = CoreAttention(config, self.layer_number)
  332. # Output.
  333. self.dense = nn.Linear(
  334. self.projection_size,
  335. config.hidden_size,
  336. bias=config.add_bias_linear,
  337. device=device,
  338. **_config_to_kwargs(config))
  339. def _allocate_memory(self,
  340. inference_max_sequence_len,
  341. batch_size,
  342. device=None,
  343. dtype=None):
  344. if self.multi_query_attention:
  345. num_attention_heads = self.num_multi_query_groups_per_partition
  346. else:
  347. num_attention_heads = self.num_attention_heads_per_partition
  348. return torch.empty(
  349. inference_max_sequence_len,
  350. batch_size,
  351. num_attention_heads,
  352. self.hidden_size_per_attention_head,
  353. dtype=dtype,
  354. device=device,
  355. )
  356. def forward(self,
  357. hidden_states,
  358. attention_mask,
  359. rotary_pos_emb,
  360. kv_cache=None,
  361. use_cache=True):
  362. # hidden_states: [sq, b, h]
  363. # =================================================
  364. # Pre-allocate memory for key-values for inference.
  365. # =================================================
  366. # =====================
  367. # Query, Key, and Value
  368. # =====================
  369. # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
  370. mixed_x_layer = self.query_key_value(hidden_states)
  371. if self.multi_query_attention:
  372. (query_layer, key_layer, value_layer) = mixed_x_layer.split(
  373. [
  374. self.num_attention_heads_per_partition
  375. * self.hidden_size_per_attention_head,
  376. self.num_multi_query_groups_per_partition
  377. * self.hidden_size_per_attention_head,
  378. self.num_multi_query_groups_per_partition
  379. * self.hidden_size_per_attention_head,
  380. ],
  381. dim=-1,
  382. )
  383. query_layer = query_layer.view(query_layer.size()[:-1] + (
  384. self.num_attention_heads_per_partition,
  385. self.hidden_size_per_attention_head))
  386. key_layer = key_layer.view(key_layer.size()[:-1] + (
  387. self.num_multi_query_groups_per_partition,
  388. self.hidden_size_per_attention_head))
  389. value_layer = value_layer.view(value_layer.size()[:-1] + (
  390. self.num_multi_query_groups_per_partition,
  391. self.hidden_size_per_attention_head))
  392. else:
  393. new_tensor_shape = mixed_x_layer.size()[:-1] + \
  394. (self.num_attention_heads_per_partition, # noqa
  395. 3 * self.hidden_size_per_attention_head) # noqa
  396. mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
  397. # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
  398. (query_layer, key_layer,
  399. value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
  400. # apply relative positional encoding (rotary embedding)
  401. if rotary_pos_emb is not None:
  402. query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
  403. key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
  404. # adjust key and value for inference
  405. if kv_cache is not None:
  406. cache_k, cache_v = kv_cache
  407. key_layer = torch.cat((cache_k, key_layer), dim=0)
  408. value_layer = torch.cat((cache_v, value_layer), dim=0)
  409. if use_cache:
  410. kv_cache = (key_layer, value_layer)
  411. else:
  412. kv_cache = None
  413. if self.multi_query_attention:
  414. key_layer = key_layer.unsqueeze(-2)
  415. key_layer = key_layer.expand(
  416. -1, -1, -1, self.num_attention_heads_per_partition
  417. // self.num_multi_query_groups_per_partition, -1)
  418. key_layer = key_layer.contiguous().view(
  419. key_layer.size()[:2] + (self.num_attention_heads_per_partition,
  420. self.hidden_size_per_attention_head))
  421. value_layer = value_layer.unsqueeze(-2)
  422. value_layer = value_layer.expand(
  423. -1, -1, -1, self.num_attention_heads_per_partition
  424. // self.num_multi_query_groups_per_partition, -1)
  425. value_layer = value_layer.contiguous().view(
  426. value_layer.size()[:2]
  427. + (self.num_attention_heads_per_partition,
  428. self.hidden_size_per_attention_head))
  429. # ==================================
  430. # core attention computation
  431. # ==================================
  432. context_layer = self.core_attention(query_layer, key_layer,
  433. value_layer, attention_mask)
  434. # =================
  435. # Output. [sq, b, h]
  436. # =================
  437. output = self.dense(context_layer)
  438. return output, kv_cache
  439. def _config_to_kwargs(args):
  440. common_kwargs = {
  441. 'dtype': args.torch_dtype,
  442. }
  443. return common_kwargs
  444. class MLP(torch.nn.Module):
  445. """MLP.
  446. MLP will take the input with h hidden state, project it to 4*h
  447. hidden dimension, perform nonlinear transformation, and project the
  448. state back into h hidden dimension.
  449. """
  450. def __init__(self, config: ChatGLM2Config, device=None):
  451. super(MLP, self).__init__()
  452. self.add_bias = config.add_bias_linear
  453. # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
  454. self.dense_h_to_4h = nn.Linear(
  455. config.hidden_size,
  456. config.ffn_hidden_size * 2,
  457. bias=self.add_bias,
  458. device=device,
  459. **_config_to_kwargs(config))
  460. def swiglu(x):
  461. x = torch.chunk(x, 2, dim=-1)
  462. return F.silu(x[0]) * x[1]
  463. self.activation_func = swiglu
  464. # Project back to h.
  465. self.dense_4h_to_h = nn.Linear(
  466. config.ffn_hidden_size,
  467. config.hidden_size,
  468. bias=self.add_bias,
  469. device=device,
  470. **_config_to_kwargs(config))
  471. def forward(self, hidden_states):
  472. # [s, b, 4hp]
  473. intermediate_parallel = self.dense_h_to_4h(hidden_states)
  474. intermediate_parallel = self.activation_func(intermediate_parallel)
  475. # [s, b, h]
  476. output = self.dense_4h_to_h(intermediate_parallel)
  477. return output
  478. class GLMBlock(torch.nn.Module):
  479. """A single transformer layer.
  480. Transformer layer takes input with size [s, b, h] and returns an
  481. output of the same size.
  482. """
  483. def __init__(self, config: ChatGLM2Config, layer_number, device=None):
  484. super(GLMBlock, self).__init__()
  485. self.layer_number = layer_number
  486. self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
  487. self.fp32_residual_connection = config.fp32_residual_connection
  488. LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
  489. # Layernorm on the input data.
  490. self.input_layernorm = LayerNormFunc(
  491. config.hidden_size,
  492. eps=config.layernorm_epsilon,
  493. device=device,
  494. dtype=config.torch_dtype)
  495. # Self attention.
  496. self.self_attention = SelfAttention(
  497. config, layer_number, device=device)
  498. self.hidden_dropout = config.hidden_dropout
  499. # Layernorm on the attention output
  500. self.post_attention_layernorm = LayerNormFunc(
  501. config.hidden_size,
  502. eps=config.layernorm_epsilon,
  503. device=device,
  504. dtype=config.torch_dtype)
  505. # MLP
  506. self.mlp = MLP(config, device=device)
  507. def forward(
  508. self,
  509. hidden_states,
  510. attention_mask,
  511. rotary_pos_emb,
  512. kv_cache=None,
  513. use_cache=True,
  514. ):
  515. # hidden_states: [s, b, h]
  516. # Layer norm at the beginning of the transformer layer.
  517. layernorm_output = self.input_layernorm(hidden_states)
  518. # Self attention.
  519. attention_output, kv_cache = self.self_attention(
  520. layernorm_output,
  521. attention_mask,
  522. rotary_pos_emb,
  523. kv_cache=kv_cache,
  524. use_cache=use_cache)
  525. # Residual connection.
  526. if self.apply_residual_connection_post_layernorm:
  527. residual = layernorm_output
  528. else:
  529. residual = hidden_states
  530. layernorm_input = torch.nn.functional.dropout(
  531. attention_output, p=self.hidden_dropout, training=self.training)
  532. layernorm_input = residual + layernorm_input
  533. # Layer norm post the self attention.
  534. layernorm_output = self.post_attention_layernorm(layernorm_input)
  535. # MLP.
  536. mlp_output = self.mlp(layernorm_output)
  537. # Second residual connection.
  538. if self.apply_residual_connection_post_layernorm:
  539. residual = layernorm_output
  540. else:
  541. residual = layernorm_input
  542. output = torch.nn.functional.dropout(
  543. mlp_output, p=self.hidden_dropout, training=self.training)
  544. output = residual + output
  545. return output, kv_cache
  546. class GLMTransformer(torch.nn.Module):
  547. """Transformer class."""
  548. def __init__(self, config: ChatGLM2Config, device=None):
  549. super(GLMTransformer, self).__init__()
  550. self.fp32_residual_connection = config.fp32_residual_connection
  551. self.post_layer_norm = config.post_layer_norm
  552. # Number of layers.
  553. self.num_layers = config.num_layers
  554. # Transformer layers.
  555. def build_layer(layer_number):
  556. return GLMBlock(config, layer_number, device=device)
  557. self.layers = torch.nn.ModuleList(
  558. [build_layer(i + 1) for i in range(self.num_layers)])
  559. if self.post_layer_norm:
  560. LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
  561. # Final layer norm before output.
  562. self.final_layernorm = LayerNormFunc(
  563. config.hidden_size,
  564. eps=config.layernorm_epsilon,
  565. device=device,
  566. dtype=config.torch_dtype)
  567. self.gradient_checkpointing = False
  568. def _get_layer(self, layer_number):
  569. return self.layers[layer_number]
  570. def forward(
  571. self,
  572. hidden_states,
  573. attention_mask,
  574. rotary_pos_emb,
  575. kv_caches=None,
  576. use_cache: Optional[bool] = True,
  577. output_hidden_states: Optional[bool] = False,
  578. ):
  579. if not kv_caches:
  580. kv_caches = [None for _ in range(self.num_layers)]
  581. presents = () if use_cache else None
  582. if self.gradient_checkpointing and self.training:
  583. if use_cache:
  584. logger.warning_once(
  585. '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
  586. )
  587. use_cache = False
  588. all_self_attentions = None
  589. all_hidden_states = () if output_hidden_states else None
  590. for index in range(self.num_layers):
  591. if output_hidden_states:
  592. all_hidden_states = all_hidden_states + (hidden_states, )
  593. layer = self._get_layer(index)
  594. if self.gradient_checkpointing and self.training:
  595. layer_ret = torch.utils.checkpoint.checkpoint(
  596. layer, hidden_states, attention_mask, rotary_pos_emb,
  597. kv_caches[index], use_cache)
  598. else:
  599. layer_ret = layer(
  600. hidden_states,
  601. attention_mask,
  602. rotary_pos_emb,
  603. kv_cache=kv_caches[index],
  604. use_cache=use_cache)
  605. hidden_states, kv_cache = layer_ret
  606. if use_cache:
  607. presents = presents + (kv_cache, )
  608. if output_hidden_states:
  609. all_hidden_states = all_hidden_states + (hidden_states, )
  610. # Final layer norm.
  611. if self.post_layer_norm:
  612. hidden_states = self.final_layernorm(hidden_states)
  613. return hidden_states, presents, all_hidden_states, all_self_attentions
  614. class ChatGLMPreTrainedModel(TorchModel, PreTrainedModel):
  615. """
  616. An abstract class to handle weights initialization and
  617. a simple interface for downloading and loading pretrained models.
  618. """
  619. is_parallelizable = False
  620. supports_gradient_checkpointing = True
  621. config_class = ChatGLM2Config
  622. base_model_prefix = 'transformer'
  623. _no_split_modules = ['GLMBlock']
  624. def __init__(self, config, **kwargs):
  625. super().__init__(config.name_or_path, **kwargs)
  626. super(Model, self).__init__(config)
  627. def _init_weights(self, module: nn.Module):
  628. """Initialize the weights."""
  629. return
  630. def get_masks(self, input_ids, past_key_values, padding_mask=None):
  631. batch_size, seq_length = input_ids.shape
  632. full_attention_mask = torch.ones(
  633. batch_size, seq_length, seq_length, device=input_ids.device)
  634. full_attention_mask.tril_()
  635. past_length = 0
  636. if past_key_values:
  637. past_length = past_key_values[0][0].shape[0]
  638. if past_length:
  639. full_attention_mask = torch.cat(
  640. (
  641. torch.ones( # noqa
  642. batch_size,
  643. seq_length,
  644. past_length, # noqa
  645. device=input_ids.device),
  646. full_attention_mask), # noqa
  647. dim=-1) # noqa
  648. if padding_mask is not None:
  649. full_attention_mask = full_attention_mask * padding_mask.unsqueeze(
  650. 1)
  651. if not past_length and padding_mask is not None:
  652. full_attention_mask -= padding_mask.unsqueeze(-1) - 1
  653. full_attention_mask = (full_attention_mask < 0.5).bool()
  654. full_attention_mask.unsqueeze_(1)
  655. return full_attention_mask
  656. def get_position_ids(self, input_ids, device):
  657. batch_size, seq_length = input_ids.shape
  658. position_ids = torch.arange(
  659. seq_length, dtype=torch.long,
  660. device=device).unsqueeze(0).repeat(batch_size, 1)
  661. return position_ids
  662. def _set_gradient_checkpointing(self, module, value=False):
  663. if isinstance(module, GLMTransformer):
  664. module.gradient_checkpointing = value
  665. @classmethod
  666. def _instantiate(cls, **kwargs):
  667. """Instantiate the model.
  668. Args:
  669. kwargs: Input args.
  670. model_dir: The model dir used to load the checkpoint and the label information.
  671. Returns:
  672. The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
  673. """
  674. model_dir = kwargs.pop('model_dir', None)
  675. kwargs.pop('cfg', None)
  676. model = super(Model, cls).from_pretrained(
  677. pretrained_model_name_or_path=model_dir, **kwargs)
  678. model.model_dir = model_dir
  679. return model
  680. class Embedding(torch.nn.Module):
  681. """Language model embeddings."""
  682. def __init__(self, config: ChatGLM2Config, device=None):
  683. super(Embedding, self).__init__()
  684. self.hidden_size = config.hidden_size
  685. # Word embeddings (parallel).
  686. self.word_embeddings = nn.Embedding(
  687. config.padded_vocab_size,
  688. self.hidden_size,
  689. dtype=config.torch_dtype,
  690. device=device)
  691. self.fp32_residual_connection = config.fp32_residual_connection
  692. def forward(self, input_ids):
  693. # Embeddings.
  694. words_embeddings = self.word_embeddings(input_ids)
  695. embeddings = words_embeddings
  696. # Data format change to avoid explicit transposes : [b s h] --> [s b h].
  697. embeddings = embeddings.transpose(0, 1).contiguous()
  698. # If the input flag for fp32 residual connection is set, convert for float.
  699. if self.fp32_residual_connection:
  700. embeddings = embeddings.float()
  701. return embeddings
  702. class ChatGLMModel(ChatGLMPreTrainedModel):
  703. def __init__(self, config: ChatGLM2Config, device=None, empty_init=True):
  704. super().__init__(config)
  705. if empty_init:
  706. init_method = skip_init
  707. else:
  708. init_method = default_init
  709. init_kwargs = {}
  710. if device is not None:
  711. init_kwargs['device'] = device
  712. self.embedding = init_method(Embedding, config, **init_kwargs)
  713. self.num_layers = config.num_layers
  714. self.multi_query_group_num = config.multi_query_group_num
  715. self.kv_channels = config.kv_channels
  716. # Rotary positional embeddings
  717. self.seq_length = config.seq_length
  718. rotary_dim = (
  719. config.hidden_size // config.num_attention_heads
  720. if config.kv_channels is None else config.kv_channels)
  721. self.rotary_pos_emb = RotaryEmbedding(
  722. rotary_dim // 2,
  723. rope_ratio=config.rope_ratio,
  724. original_impl=config.original_rope,
  725. device=device,
  726. dtype=config.torch_dtype)
  727. self.encoder = init_method(GLMTransformer, config, **init_kwargs)
  728. self.output_layer = init_method(
  729. nn.Linear,
  730. config.hidden_size,
  731. config.padded_vocab_size,
  732. bias=False,
  733. dtype=config.torch_dtype,
  734. **init_kwargs)
  735. self.pre_seq_len = config.pre_seq_len
  736. self.prefix_projection = config.prefix_projection
  737. if self.pre_seq_len is not None:
  738. for param in self.parameters():
  739. param.requires_grad = False
  740. self.prefix_tokens = torch.arange(self.pre_seq_len).long()
  741. self.prefix_encoder = PrefixEncoder(config)
  742. self.dropout = torch.nn.Dropout(0.1)
  743. def get_input_embeddings(self):
  744. return self.embedding.word_embeddings
  745. def get_prompt(self, batch_size, device, dtype=torch.half):
  746. prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size,
  747. -1).to(device)
  748. past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
  749. past_key_values = past_key_values.view(batch_size, self.pre_seq_len,
  750. self.num_layers * 2,
  751. self.multi_query_group_num,
  752. self.kv_channels)
  753. # seq_len, b, nh, hidden_size
  754. past_key_values = self.dropout(past_key_values)
  755. past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
  756. return past_key_values
  757. def forward(
  758. self,
  759. input_ids,
  760. position_ids: Optional[torch.Tensor] = None,
  761. attention_mask: Optional[torch.BoolTensor] = None,
  762. full_attention_mask: Optional[torch.BoolTensor] = None,
  763. past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor],
  764. ...]] = None,
  765. inputs_embeds: Optional[torch.Tensor] = None,
  766. use_cache: Optional[bool] = None,
  767. output_hidden_states: Optional[bool] = None,
  768. return_dict: Optional[bool] = None,
  769. ):
  770. output_hidden_states = (
  771. output_hidden_states if output_hidden_states is not None else
  772. self.config.output_hidden_states)
  773. use_cache = use_cache if use_cache is not None else self.config.use_cache
  774. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  775. batch_size, seq_length = input_ids.shape
  776. if inputs_embeds is None:
  777. inputs_embeds = self.embedding(input_ids)
  778. if self.pre_seq_len is not None:
  779. if past_key_values is None:
  780. past_key_values = self.get_prompt(
  781. batch_size=batch_size,
  782. device=input_ids.device,
  783. dtype=inputs_embeds.dtype)
  784. if attention_mask is not None:
  785. attention_mask = torch.cat(
  786. [
  787. attention_mask.new_ones( # noqa
  788. (batch_size, self.pre_seq_len)),
  789. attention_mask # noqa
  790. ], # noqa
  791. dim=-1) # noqa
  792. if full_attention_mask is None:
  793. if (attention_mask is not None
  794. and not attention_mask.all()) or (past_key_values
  795. and seq_length != 1):
  796. full_attention_mask = self.get_masks(
  797. input_ids, past_key_values, padding_mask=attention_mask)
  798. # Rotary positional embeddings
  799. rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
  800. if position_ids is not None:
  801. rotary_pos_emb = rotary_pos_emb[position_ids]
  802. else:
  803. rotary_pos_emb = rotary_pos_emb[None, :seq_length]
  804. rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
  805. # Run encoder.
  806. hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
  807. inputs_embeds,
  808. full_attention_mask,
  809. rotary_pos_emb=rotary_pos_emb,
  810. kv_caches=past_key_values,
  811. use_cache=use_cache,
  812. output_hidden_states=output_hidden_states)
  813. if not return_dict:
  814. return tuple(v for v in [
  815. hidden_states, presents, all_hidden_states, all_self_attentions
  816. ] if v is not None)
  817. return BaseModelOutputWithPast(
  818. last_hidden_state=hidden_states,
  819. past_key_values=presents,
  820. hidden_states=all_hidden_states,
  821. attentions=all_self_attentions,
  822. )
  823. def quantize(self, weight_bit_width: int):
  824. from .quantization import quantize
  825. quantize(self.encoder, weight_bit_width)
  826. return self
  827. @MODELS.register_module(Tasks.chat, module_name=Models.chatglm2_6b)
  828. class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
  829. def __init__(self, config: ChatGLM2Config, empty_init=True, device=None):
  830. super().__init__(config)
  831. self.max_sequence_length = config.max_length
  832. self.transformer = ChatGLMModel(
  833. config, empty_init=empty_init, device=device)
  834. self.config = config
  835. self.quantized = False
  836. if self.config.quantization_bit:
  837. self.quantize(self.config.quantization_bit, empty_init=True)
  838. def _update_model_kwargs_for_generation(
  839. self,
  840. outputs: ModelOutput,
  841. model_kwargs: Dict[str, Any],
  842. is_encoder_decoder: bool = False,
  843. standardize_cache_format: bool = False,
  844. ) -> Dict[str, Any]:
  845. # update past_key_values
  846. model_kwargs['past_key_values'] = self._extract_past_from_model_output(
  847. outputs, standardize_cache_format=standardize_cache_format)
  848. # update attention mask
  849. if 'attention_mask' in model_kwargs:
  850. attention_mask = model_kwargs['attention_mask']
  851. model_kwargs['attention_mask'] = torch.cat(
  852. [ # noqa
  853. attention_mask, # noqa
  854. attention_mask.new_ones(
  855. (attention_mask.shape[0], 1)) # noqa
  856. ], # noqa
  857. dim=-1) # noqa
  858. # update position ids
  859. if 'position_ids' in model_kwargs:
  860. position_ids = model_kwargs['position_ids']
  861. new_position_id = position_ids[..., -1:].clone()
  862. new_position_id += 1
  863. model_kwargs['position_ids'] = torch.cat(
  864. [position_ids, new_position_id], dim=-1)
  865. model_kwargs['is_first_forward'] = False
  866. return model_kwargs
  867. def prepare_inputs_for_generation(
  868. self,
  869. input_ids: torch.LongTensor,
  870. past_key_values: Optional[torch.Tensor] = None,
  871. attention_mask: Optional[torch.Tensor] = None,
  872. position_ids: Optional[torch.Tensor] = None,
  873. is_first_forward: bool = True,
  874. **kwargs) -> dict:
  875. # only last token for input_ids if past is not None
  876. if position_ids is None:
  877. position_ids = self.get_position_ids(
  878. input_ids, device=input_ids.device)
  879. if not is_first_forward:
  880. position_ids = position_ids[..., -1:]
  881. input_ids = input_ids[:, -1:]
  882. return {
  883. 'input_ids': input_ids,
  884. 'past_key_values': past_key_values,
  885. 'position_ids': position_ids,
  886. 'attention_mask': attention_mask,
  887. 'return_last_logit': True
  888. }
  889. def forward(
  890. self,
  891. input_ids: Optional[torch.Tensor] = None,
  892. position_ids: Optional[torch.Tensor] = None,
  893. attention_mask: Optional[torch.Tensor] = None,
  894. past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
  895. inputs_embeds: Optional[torch.Tensor] = None,
  896. labels: Optional[torch.Tensor] = None,
  897. use_cache: Optional[bool] = None,
  898. output_attentions: Optional[bool] = None,
  899. output_hidden_states: Optional[bool] = None,
  900. return_dict: Optional[bool] = None,
  901. return_last_logit: Optional[bool] = False,
  902. ):
  903. use_cache = use_cache if use_cache is not None else self.config.use_cache
  904. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  905. transformer_outputs = self.transformer(
  906. input_ids=input_ids,
  907. position_ids=position_ids,
  908. attention_mask=attention_mask,
  909. past_key_values=past_key_values,
  910. inputs_embeds=inputs_embeds,
  911. use_cache=use_cache,
  912. output_hidden_states=output_hidden_states,
  913. return_dict=return_dict,
  914. )
  915. hidden_states = transformer_outputs[0]
  916. if return_last_logit:
  917. hidden_states = hidden_states[-1:]
  918. lm_logits = self.transformer.output_layer(hidden_states)
  919. lm_logits = lm_logits.transpose(0, 1).contiguous()
  920. loss = None
  921. if labels is not None:
  922. lm_logits = lm_logits.to(torch.float32)
  923. # Shift so that tokens < n predict n
  924. shift_logits = lm_logits[..., :-1, :].contiguous()
  925. shift_labels = labels[..., 1:].contiguous()
  926. # Flatten the tokens
  927. loss_fct = CrossEntropyLoss(ignore_index=-100)
  928. shift_labels = shift_labels.to(shift_logits.device)
  929. loss = loss_fct(
  930. shift_logits.view(-1, shift_logits.size(-1)),
  931. shift_labels.view(-1))
  932. lm_logits = lm_logits.to(hidden_states.dtype)
  933. loss = loss.to(hidden_states.dtype)
  934. if not return_dict:
  935. output = (lm_logits, ) + transformer_outputs[1:]
  936. return ((loss, ) + output) if loss is not None else output
  937. return CausalLMOutputWithPast(
  938. loss=loss,
  939. logits=lm_logits,
  940. past_key_values=transformer_outputs.past_key_values,
  941. hidden_states=transformer_outputs.hidden_states,
  942. attentions=transformer_outputs.attentions,
  943. )
  944. @staticmethod
  945. def _reorder_cache(
  946. past: Tuple[Tuple[torch.Tensor, torch.Tensor],
  947. ...], beam_idx: torch.LongTensor
  948. ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
  949. """
  950. This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
  951. [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
  952. beam_idx at every generation step.
  953. Output shares the same memory storage as `past`.
  954. """
  955. return tuple((
  956. layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
  957. layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
  958. ) for layer_past in past)
  959. def process_response(self, response):
  960. response = response.strip()
  961. response = response.replace('[[训练时间]]', '2023年')
  962. return response
  963. def build_inputs(self,
  964. tokenizer,
  965. query: str,
  966. history: List[Tuple[str, str]] = None):
  967. prompt = tokenizer.build_prompt(query, history=history)
  968. inputs = tokenizer([prompt], return_tensors='pt')
  969. inputs = inputs.to(self.device)
  970. return inputs
  971. def build_stream_inputs(self,
  972. tokenizer,
  973. query: str,
  974. history: List[Tuple[str, str]] = None):
  975. if history:
  976. prompt = '\n\n[Round {}]\n\n问:{}\n\n答:'.format(
  977. len(history) + 1, query)
  978. input_ids = tokenizer.encode(prompt, add_special_tokens=False)
  979. input_ids = input_ids[1:]
  980. inputs = tokenizer.batch_encode_plus([(input_ids, None)],
  981. return_tensors='pt',
  982. add_special_tokens=False)
  983. else:
  984. prompt = '[Round {}]\n\n问:{}\n\n答:'.format(len(history) + 1, query)
  985. inputs = tokenizer([prompt], return_tensors='pt')
  986. inputs = inputs.to(self.device)
  987. return inputs
  988. @torch.no_grad()
  989. def _chat(self,
  990. tokenizer,
  991. query: str,
  992. history: List[Tuple[str, str]] = None,
  993. max_length: int = None,
  994. num_beams=1,
  995. do_sample=True,
  996. top_p=0.8,
  997. temperature=0.8,
  998. logits_processor=None,
  999. **kwargs):
  1000. if history is None:
  1001. history = []
  1002. if logits_processor is None:
  1003. logits_processor = LogitsProcessorList()
  1004. logits_processor.append(InvalidScoreLogitsProcessor())
  1005. if max_length is None:
  1006. max_length = self.seq_length
  1007. gen_kwargs = {
  1008. 'max_length': max_length,
  1009. 'num_beams': num_beams,
  1010. 'do_sample': do_sample,
  1011. 'top_p': top_p,
  1012. 'temperature': temperature,
  1013. 'logits_processor': logits_processor,
  1014. **kwargs
  1015. }
  1016. inputs = self.build_inputs(tokenizer, query, history=history)
  1017. outputs = self.generate(**inputs, **gen_kwargs)
  1018. outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):]
  1019. response = tokenizer.decode(outputs)
  1020. response = self.process_response(response)
  1021. history = history + [(query, response)]
  1022. return response, history
  1023. @torch.no_grad()
  1024. def stream_chat(self,
  1025. tokenizer,
  1026. query: str,
  1027. history: List[Tuple[str, str]] = None,
  1028. past_key_values=None,
  1029. max_length: int = None,
  1030. do_sample=True,
  1031. top_p=0.8,
  1032. temperature=0.8,
  1033. logits_processor=None,
  1034. return_past_key_values=False,
  1035. **kwargs):
  1036. if history is None:
  1037. history = []
  1038. if logits_processor is None:
  1039. logits_processor = LogitsProcessorList()
  1040. logits_processor.append(InvalidScoreLogitsProcessor())
  1041. if max_length is None:
  1042. max_length = self.seq_length
  1043. gen_kwargs = {
  1044. 'max_length': max_length,
  1045. 'do_sample': do_sample,
  1046. 'top_p': top_p,
  1047. 'temperature': temperature,
  1048. 'logits_processor': logits_processor,
  1049. **kwargs
  1050. }
  1051. if past_key_values is None and not return_past_key_values:
  1052. inputs = self.build_inputs(tokenizer, query, history=history)
  1053. else:
  1054. inputs = self.build_stream_inputs(
  1055. tokenizer, query, history=history)
  1056. if past_key_values is not None:
  1057. past_length = past_key_values[0][0].shape[0]
  1058. if self.transformer.pre_seq_len is not None:
  1059. past_length -= self.transformer.pre_seq_len
  1060. inputs.position_ids += past_length
  1061. attention_mask = inputs.attention_mask
  1062. attention_mask = torch.cat(
  1063. (attention_mask.new_ones(1, past_length), attention_mask),
  1064. dim=1)
  1065. inputs['attention_mask'] = attention_mask
  1066. for outputs in self.stream_generate(
  1067. **inputs,
  1068. past_key_values=past_key_values,
  1069. return_past_key_values=return_past_key_values,
  1070. **gen_kwargs):
  1071. if return_past_key_values:
  1072. outputs, past_key_values = outputs
  1073. outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):]
  1074. response = tokenizer.decode(outputs)
  1075. if response and response[-1] != '�':
  1076. response = self.process_response(response)
  1077. new_history = history + [(query, response)]
  1078. if return_past_key_values:
  1079. yield response, new_history, past_key_values
  1080. else:
  1081. yield response, new_history
  1082. @torch.no_grad()
  1083. def stream_generate(
  1084. self,
  1085. input_ids,
  1086. generation_config: Optional[GenerationConfig] = None,
  1087. logits_processor: Optional[LogitsProcessorList] = None,
  1088. stopping_criteria: Optional[StoppingCriteriaList] = None,
  1089. prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
  1090. List[int]]] = None,
  1091. return_past_key_values=False,
  1092. **kwargs,
  1093. ):
  1094. _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
  1095. if generation_config is None:
  1096. generation_config = self.generation_config
  1097. generation_config = copy.deepcopy(generation_config)
  1098. model_kwargs = generation_config.update(**kwargs)
  1099. _, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
  1100. if isinstance(eos_token_id, int):
  1101. eos_token_id = [eos_token_id]
  1102. has_default_max_length = kwargs.get(
  1103. 'max_length') is None and generation_config.max_length is not None
  1104. if has_default_max_length and generation_config.max_new_tokens is None:
  1105. warnings.warn(
  1106. f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
  1107. 'This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we'
  1108. ' recommend using `max_new_tokens` to control the maximum length of the generation.',
  1109. UserWarning,
  1110. )
  1111. elif generation_config.max_new_tokens is not None:
  1112. generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
  1113. if not has_default_max_length:
  1114. logger.warn(
  1115. f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(='
  1116. f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. '
  1117. 'Please refer to the documentation for more information. '
  1118. '(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)',
  1119. UserWarning,
  1120. )
  1121. if input_ids_seq_length >= generation_config.max_length:
  1122. input_ids_string = 'decoder_input_ids' if self.config.is_encoder_decoder else 'input_ids'
  1123. logger.warning(
  1124. f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to'
  1125. f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider'
  1126. ' increasing `max_new_tokens`.')
  1127. # 2. Set generation parameters if not already defined
  1128. logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList(
  1129. )
  1130. stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList(
  1131. )
  1132. logits_processor = self._get_logits_processor(
  1133. generation_config=generation_config,
  1134. input_ids_seq_length=input_ids_seq_length,
  1135. encoder_input_ids=input_ids,
  1136. prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
  1137. logits_processor=logits_processor,
  1138. )
  1139. stopping_criteria = self._get_stopping_criteria(
  1140. generation_config=generation_config,
  1141. stopping_criteria=stopping_criteria)
  1142. logits_warper = self._get_logits_warper(generation_config)
  1143. unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
  1144. scores = None
  1145. while True:
  1146. model_inputs = self.prepare_inputs_for_generation(
  1147. input_ids, **model_kwargs)
  1148. # forward pass to get next token
  1149. outputs = self(
  1150. **model_inputs,
  1151. return_dict=True,
  1152. output_attentions=False,
  1153. output_hidden_states=False,
  1154. )
  1155. next_token_logits = outputs.logits[:, -1, :]
  1156. # pre-process distribution
  1157. next_token_scores = logits_processor(input_ids, next_token_logits)
  1158. next_token_scores = logits_warper(input_ids, next_token_scores)
  1159. # sample
  1160. probs = nn.functional.softmax(next_token_scores, dim=-1)
  1161. if generation_config.do_sample:
  1162. next_tokens = torch.multinomial(
  1163. probs, num_samples=1).squeeze(1)
  1164. else:
  1165. next_tokens = torch.argmax(probs, dim=-1)
  1166. # update generated ids, model inputs, and length for next step
  1167. input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
  1168. model_kwargs = self._update_model_kwargs_for_generation(
  1169. outputs,
  1170. model_kwargs,
  1171. is_encoder_decoder=self.config.is_encoder_decoder)
  1172. unfinished_sequences = unfinished_sequences.mul(
  1173. (sum(next_tokens != i for i in eos_token_id)).long())
  1174. if return_past_key_values:
  1175. yield input_ids, outputs.past_key_values
  1176. else:
  1177. yield input_ids
  1178. # stop when each sentence is finished, or if we exceed the maximum length
  1179. if unfinished_sequences.max() == 0 or stopping_criteria(
  1180. input_ids, scores):
  1181. break
  1182. def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
  1183. if bits == 0:
  1184. return
  1185. from .quantization import quantize
  1186. if self.quantized:
  1187. logger.info('Already quantized.')
  1188. return self
  1189. self.quantized = True
  1190. self.config.quantization_bit = bits
  1191. self.transformer.encoder = quantize(
  1192. self.transformer.encoder,
  1193. bits,
  1194. empty_init=empty_init,
  1195. device=device,
  1196. **kwargs)
  1197. return self
  1198. def chat(self, input: Dict, tokenizer) -> Dict:
  1199. text = input['text']
  1200. history = input['history']
  1201. # args
  1202. if 'max_length' in input:
  1203. max_length = input['max_length']
  1204. else:
  1205. max_length = 2048
  1206. if 'temperature' in input:
  1207. temperature = input['temperature']
  1208. else:
  1209. temperature = 0.95
  1210. if 'num_beams' in input:
  1211. num_beams = input['num_beams']
  1212. else:
  1213. num_beams = 1
  1214. if 'do_sample' in input:
  1215. do_sample = input['do_sample']
  1216. else:
  1217. do_sample = True
  1218. if type(history) == torch.Tensor:
  1219. history = history.tolist()
  1220. response, history = self._chat(
  1221. tokenizer,
  1222. text,
  1223. history,
  1224. max_length=max_length,
  1225. temperature=temperature,
  1226. num_beams=num_beams,
  1227. do_sample=do_sample)
  1228. return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history}