text_generation.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572
  1. """ PyTorch ChatGLM model. """
  2. import copy
  3. import math
  4. import os
  5. import re
  6. import sys
  7. import warnings
  8. from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  9. import torch
  10. import torch.nn.functional as F
  11. import torch.utils.checkpoint
  12. from torch import nn
  13. from torch.nn import CrossEntropyLoss, LayerNorm
  14. from torch.nn.utils import skip_init
  15. from transformers.generation.logits_process import LogitsProcessor
  16. from transformers.generation.utils import (GenerationConfig,
  17. LogitsProcessorList, ModelOutput,
  18. StoppingCriteriaList)
  19. from transformers.modeling_outputs import (
  20. BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions,
  21. CausalLMOutputWithPast)
  22. from transformers.modeling_utils import PreTrainedModel
  23. from transformers.utils import (add_code_sample_docstrings,
  24. add_start_docstrings,
  25. add_start_docstrings_to_model_forward)
  26. from modelscope.metainfo import Models
  27. from modelscope.models import MODELS, Model, TorchModel
  28. from modelscope.outputs import OutputKeys
  29. from modelscope.utils import logger as logging
  30. from modelscope.utils.constant import Tasks
  31. from .configuration import ChatGLMConfig
  32. from .tokenization import ChatGLMTokenizer
# flags required to enable jit fusion kernels
# The block is skipped when sys.platform is 'darwin' (macOS).
# NOTE(review): these are private torch._C hooks; judging by their names they
# turn off the TorchScript profiling executor/mode and allow kernel fusion on
# CPU and GPU for the @torch.jit.script functions in this module — confirm
# against the installed torch version, as private APIs may change.
if sys.platform != 'darwin':
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)
# Module-wide logger obtained from modelscope's logging utilities.
logger = logging.get_logger()
  40. _CHECKPOINT_FOR_DOC = 'THUDM/ChatGLM-6B'
  41. _CONFIG_FOR_DOC = 'ChatGLM6BConfig'
  42. CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
  43. 'THUDM/chatglm-6b',
  44. # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
  45. ]
  46. class InvalidScoreLogitsProcessor(LogitsProcessor):
  47. def __call__(self, input_ids: torch.LongTensor,
  48. scores: torch.FloatTensor) -> torch.FloatTensor:
  49. if torch.isnan(scores).any() or torch.isinf(scores).any():
  50. scores.zero_()
  51. scores[..., 5] = 5e4
  52. return scores
  53. def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
  54. """Load tf checkpoints in a pytorch model."""
  55. try:
  56. import re
  57. import numpy as np
  58. import tensorflow as tf
  59. except ImportError:
  60. logger.error(
  61. 'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see '
  62. 'https://www.tensorflow.org/install/ for installation instructions.'
  63. )
  64. raise
  65. tf_path = os.path.abspath(tf_checkpoint_path)
  66. logger.info(f'Converting TensorFlow checkpoint from {tf_path}')
  67. # Load weights from TF model
  68. init_vars = tf.train.list_variables(tf_path)
  69. names = []
  70. arrays = []
  71. for name, shape in init_vars:
  72. logger.info(f'Loading TF weight {name} with shape {shape}')
  73. array = tf.train.load_variable(tf_path, name)
  74. names.append(name)
  75. arrays.append(array)
  76. for name, array in zip(names, arrays):
  77. name = name.split('/')
  78. # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
  79. # which are not required for using pretrained model
  80. if any(n in [
  81. 'adam_v', 'adam_m', 'AdamWeightDecayOptimizer',
  82. 'AdamWeightDecayOptimizer_1', 'global_step'
  83. ] for n in name):
  84. logger.info(f"Skipping {'/'.join(name)}")
  85. continue
  86. pointer = model
  87. for m_name in name:
  88. if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
  89. scope_names = re.split(r'_(\d+)', m_name)
  90. else:
  91. scope_names = [m_name]
  92. if scope_names[0] == 'kernel' or scope_names[0] == 'gamma':
  93. pointer = getattr(pointer, 'weight')
  94. elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
  95. pointer = getattr(pointer, 'bias')
  96. elif scope_names[0] == 'output_weights':
  97. pointer = getattr(pointer, 'weight')
  98. elif scope_names[0] == 'squad':
  99. pointer = getattr(pointer, 'classifier')
  100. else:
  101. try:
  102. pointer = getattr(pointer, scope_names[0])
  103. except AttributeError:
  104. logger.info(f"Skipping {'/'.join(name)}")
  105. continue
  106. if len(scope_names) >= 2:
  107. num = int(scope_names[1])
  108. pointer = pointer[num]
  109. if m_name[-11:] == '_embeddings':
  110. pointer = getattr(pointer, 'weight')
  111. elif m_name == 'kernel':
  112. array = np.transpose(array)
  113. try:
  114. assert (
  115. pointer.shape == array.shape
  116. ), f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched'
  117. except AssertionError as e:
  118. e.args += (pointer.shape, array.shape)
  119. raise
  120. logger.info(f'Initialize PyTorch weight {name}')
  121. pointer.data = torch.from_numpy(array)
  122. return model
  123. class PrefixEncoder(torch.nn.Module):
  124. """
  125. The torch.nn model to encode the prefix
  126. Input shape: (batch-size, prefix-length)
  127. Output shape: (batch-size, prefix-length, 2*layers*hidden)
  128. """
  129. def __init__(self, config):
  130. super().__init__()
  131. self.prefix_projection = config.prefix_projection
  132. if self.prefix_projection:
  133. # Use a two-layer MLP to encode the prefix
  134. self.embedding = torch.nn.Embedding(config.pre_seq_len,
  135. config.hidden_size)
  136. self.trans = torch.nn.Sequential(
  137. torch.nn.Linear(config.hidden_size, config.hidden_size),
  138. torch.nn.Tanh(),
  139. torch.nn.Linear(config.hidden_size,
  140. config.num_layers * config.hidden_size * 2))
  141. else:
  142. self.embedding = torch.nn.Embedding(
  143. config.pre_seq_len, config.num_layers * config.hidden_size * 2)
  144. def forward(self, prefix: torch.Tensor):
  145. if self.prefix_projection:
  146. prefix_tokens = self.embedding(prefix)
  147. past_key_values = self.trans(prefix_tokens)
  148. else:
  149. past_key_values = self.embedding(prefix)
  150. return past_key_values
  151. @torch.jit.script
  152. def gelu_impl(x):
  153. """OpenAI's gelu implementation."""
  154. return 0.5 * x * (
  155. 1.0 + torch.tanh(0.7978845608028654 * x * # noqa
  156. (1.0 + 0.044715 * x * x))) # noqa
  157. def gelu(x):
  158. return gelu_impl(x)
  159. class RotaryEmbedding(torch.nn.Module):
  160. def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
  161. super().__init__()
  162. inv_freq = 1. / (base**(torch.arange(0, dim, 2).float() / dim))
  163. inv_freq = inv_freq.half()
  164. self.learnable = learnable
  165. if learnable:
  166. self.inv_freq = torch.nn.Parameter(inv_freq)
  167. self.max_seq_len_cached = None
  168. else:
  169. self.register_buffer('inv_freq', inv_freq)
  170. self.max_seq_len_cached = None
  171. self.cos_cached = None
  172. self.sin_cached = None
  173. self.precision = precision
  174. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
  175. missing_keys, unexpected_keys, error_msgs):
  176. pass
  177. def forward(self, x, seq_dim=1, seq_len=None):
  178. if seq_len is None:
  179. seq_len = x.shape[seq_dim]
  180. if self.max_seq_len_cached is None or (
  181. seq_len > self.max_seq_len_cached): # noqa
  182. self.max_seq_len_cached = None if self.learnable else seq_len
  183. t = torch.arange(
  184. seq_len, device=x.device, dtype=self.inv_freq.dtype)
  185. freqs = torch.einsum('i,j->ij', t, self.inv_freq)
  186. # Different from paper, but it uses a different permutation in order to obtain the same calculation
  187. emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
  188. if self.precision == torch.bfloat16:
  189. emb = emb.float()
  190. # [sx, 1 (b * np), hn]
  191. cos_cached = emb.cos()[:, None, :]
  192. sin_cached = emb.sin()[:, None, :]
  193. if self.precision == torch.bfloat16:
  194. cos_cached = cos_cached.bfloat16()
  195. sin_cached = sin_cached.bfloat16()
  196. if self.learnable:
  197. return cos_cached, sin_cached
  198. self.cos_cached, self.sin_cached = cos_cached, sin_cached
  199. return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
  200. def _apply(self, fn):
  201. if self.cos_cached is not None:
  202. self.cos_cached = fn(self.cos_cached)
  203. if self.sin_cached is not None:
  204. self.sin_cached = fn(self.sin_cached)
  205. return super()._apply(fn)
  206. def rotate_half(x):
  207. x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
  208. return torch.cat(
  209. (-x2, x1),
  210. dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
  211. @torch.jit.script
  212. def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
  213. # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
  214. cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
  215. F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
  216. q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (
  217. rotate_half(k) * sin)
  218. return q, k
  219. def attention_fn(
  220. self,
  221. query_layer,
  222. key_layer,
  223. value_layer,
  224. attention_mask,
  225. hidden_size_per_partition,
  226. layer_id,
  227. layer_past=None,
  228. scaling_attention_score=True,
  229. use_cache=False,
  230. ):
  231. if layer_past is not None:
  232. past_key, past_value = layer_past[0], layer_past[1]
  233. key_layer = torch.cat((past_key, key_layer), dim=0)
  234. value_layer = torch.cat((past_value, value_layer), dim=0)
  235. # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
  236. seq_len, b, nh, hidden_size = key_layer.shape
  237. if use_cache:
  238. present = (key_layer, value_layer)
  239. else:
  240. present = None
  241. query_key_layer_scaling_coeff = float(layer_id + 1)
  242. if scaling_attention_score:
  243. query_layer = query_layer / (
  244. math.sqrt(hidden_size) * query_key_layer_scaling_coeff)
  245. # ===================================
  246. # Raw attention scores. [b, np, s, s]
  247. # ===================================
  248. # [b, np, sq, sk]
  249. output_size = (query_layer.size(1), query_layer.size(2),
  250. query_layer.size(0), key_layer.size(0))
  251. # [sq, b, np, hn] -> [sq, b * np, hn]
  252. query_layer = query_layer.view(output_size[2],
  253. output_size[0] * output_size[1], -1)
  254. # [sk, b, np, hn] -> [sk, b * np, hn]
  255. key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1],
  256. -1)
  257. matmul_result = torch.zeros(
  258. 1,
  259. 1,
  260. 1,
  261. dtype=query_layer.dtype,
  262. device=query_layer.device,
  263. )
  264. matmul_result = torch.baddbmm(
  265. matmul_result,
  266. query_layer.transpose(0, 1), # [b * np, sq, hn]
  267. key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
  268. beta=0.0,
  269. alpha=1.0,
  270. )
  271. # change view to [b, np, sq, sk]
  272. attention_scores = matmul_result.view(*output_size)
  273. if self.scale_mask_softmax:
  274. self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
  275. attention_probs = self.scale_mask_softmax(attention_scores,
  276. attention_mask.contiguous())
  277. else:
  278. if not (attention_mask == 0).all():
  279. # if auto-regressive, skip
  280. attention_scores.masked_fill_(attention_mask, -10000.0)
  281. dtype = attention_scores.dtype
  282. attention_scores = attention_scores.float()
  283. attention_scores = attention_scores * query_key_layer_scaling_coeff
  284. attention_probs = F.softmax(attention_scores, dim=-1)
  285. attention_probs = attention_probs.type(dtype)
  286. # =========================
  287. # Context layer. [sq, b, hp]
  288. # =========================
  289. # value_layer -> context layer.
  290. # [sk, b, np, hn] --> [b, np, sq, hn]
  291. # context layer shape: [b, np, sq, hn]
  292. output_size = (value_layer.size(1), value_layer.size(2),
  293. query_layer.size(0), value_layer.size(3))
  294. # change view [sk, b * np, hn]
  295. value_layer = value_layer.view(
  296. value_layer.size(0), output_size[0] * output_size[1], -1)
  297. # change view [b * np, sq, sk]
  298. attention_probs = attention_probs.view(output_size[0] * output_size[1],
  299. output_size[2], -1)
  300. # matmul: [b * np, sq, hn]
  301. context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
  302. # change view [b, np, sq, hn]
  303. context_layer = context_layer.view(*output_size)
  304. # [b, np, sq, hn] --> [sq, b, np, hn]
  305. context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
  306. # [sq, b, np, hn] --> [sq, b, hp]
  307. new_context_layer_shape = context_layer.size()[:-2] + (
  308. hidden_size_per_partition, )
  309. context_layer = context_layer.view(*new_context_layer_shape)
  310. outputs = (context_layer, present, attention_probs)
  311. return outputs
  312. class SelfAttention(torch.nn.Module):
  313. def __init__(self,
  314. hidden_size,
  315. num_attention_heads,
  316. layer_id,
  317. hidden_size_per_attention_head=None,
  318. bias=True,
  319. params_dtype=torch.float,
  320. position_encoding_2d=True):
  321. super(SelfAttention, self).__init__()
  322. self.layer_id = layer_id
  323. self.hidden_size = hidden_size
  324. self.hidden_size_per_partition = hidden_size
  325. self.num_attention_heads = num_attention_heads
  326. self.num_attention_heads_per_partition = num_attention_heads
  327. self.position_encoding_2d = position_encoding_2d
  328. self.rotary_emb = RotaryEmbedding( # noqa
  329. self.hidden_size // # noqa
  330. (self.num_attention_heads * 2) if position_encoding_2d else # noqa
  331. self.hidden_size // self.num_attention_heads, # noqa
  332. base=10000, # noqa
  333. precision=torch.half, # noqa
  334. learnable=False, # noqa
  335. ) # noqa
  336. self.scale_mask_softmax = None
  337. if hidden_size_per_attention_head is None:
  338. self.hidden_size_per_attention_head = hidden_size // num_attention_heads
  339. else:
  340. self.hidden_size_per_attention_head = hidden_size_per_attention_head
  341. self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
  342. # Strided linear layer.
  343. self.query_key_value = skip_init(
  344. torch.nn.Linear,
  345. hidden_size,
  346. 3 * self.inner_hidden_size,
  347. bias=bias,
  348. dtype=params_dtype,
  349. )
  350. self.dense = skip_init(
  351. torch.nn.Linear,
  352. self.inner_hidden_size,
  353. hidden_size,
  354. bias=bias,
  355. dtype=params_dtype,
  356. )
  357. @staticmethod
  358. def attention_mask_func(attention_scores, attention_mask):
  359. attention_scores.masked_fill_(attention_mask, -10000.0)
  360. return attention_scores
  361. def split_tensor_along_last_dim(self,
  362. tensor,
  363. num_partitions,
  364. contiguous_split_chunks=False):
  365. """Split a tensor along its last dimension.
  366. Arguments:
  367. tensor: input tensor.
  368. num_partitions: number of partitions to split the tensor
  369. contiguous_split_chunks: If True, make each chunk contiguous
  370. in memory.
  371. """
  372. # Get the size and dimension.
  373. last_dim = tensor.dim() - 1
  374. last_dim_size = tensor.size()[last_dim] // num_partitions
  375. # Split.
  376. tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
  377. # Note: torch.split does not create contiguous tensors by default.
  378. if contiguous_split_chunks:
  379. return tuple(chunk.contiguous() for chunk in tensor_list)
  380. return tensor_list
  381. def forward(
  382. self,
  383. hidden_states: torch.Tensor,
  384. position_ids,
  385. attention_mask: torch.Tensor,
  386. layer_id,
  387. layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  388. use_cache: bool = False,
  389. output_attentions: bool = False,
  390. ):
  391. """
  392. hidden_states: [seq_len, batch, hidden_size]
  393. attention_mask: [(1, 1), seq_len, seq_len]
  394. """
  395. # [seq_len, batch, 3 * hidden_size]
  396. mixed_raw_layer = self.query_key_value(hidden_states)
  397. # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads,
  398. # 3 * hidden_size_per_attention_head]
  399. new_tensor_shape = mixed_raw_layer.size()[:-1] + (
  400. self.num_attention_heads_per_partition,
  401. 3 * self.hidden_size_per_attention_head,
  402. )
  403. mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape)
  404. # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
  405. (query_layer, key_layer,
  406. value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3)
  407. if self.position_encoding_2d:
  408. q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
  409. k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
  410. cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
  411. position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
  412. position_ids[:, 1, :].transpose(0, 1).contiguous()
  413. q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
  414. q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin,
  415. block_position_ids)
  416. query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
  417. key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
  418. else:
  419. position_ids = position_ids.transpose(0, 1)
  420. cos, sin = self.rotary_emb(
  421. value_layer, seq_len=position_ids.max() + 1)
  422. # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
  423. query_layer, key_layer = apply_rotary_pos_emb_index(
  424. query_layer, key_layer, cos, sin, position_ids)
  425. # [seq_len, batch, hidden_size]
  426. context_layer, present, attention_probs = attention_fn(
  427. self=self,
  428. query_layer=query_layer,
  429. key_layer=key_layer,
  430. value_layer=value_layer,
  431. attention_mask=attention_mask,
  432. hidden_size_per_partition=self.hidden_size_per_partition,
  433. layer_id=layer_id,
  434. layer_past=layer_past,
  435. use_cache=use_cache)
  436. output = self.dense(context_layer)
  437. outputs = (output, present)
  438. if output_attentions:
  439. outputs += (attention_probs, )
  440. return outputs # output, present, attention_probs
  441. class GEGLU(torch.nn.Module):
  442. def __init__(self):
  443. super().__init__()
  444. self.activation_fn = F.gelu
  445. def forward(self, x):
  446. # dim=-1 breaks in jit for pt<1.10
  447. x1, x2 = x.chunk(2, dim=(x.ndim - 1))
  448. return x1 * self.activation_fn(x2)
  449. class GLU(torch.nn.Module):
  450. def __init__(self,
  451. hidden_size,
  452. inner_hidden_size=None,
  453. layer_id=None,
  454. bias=True,
  455. activation_func=gelu,
  456. params_dtype=torch.float):
  457. super(GLU, self).__init__()
  458. self.layer_id = layer_id
  459. self.activation_func = activation_func
  460. # Project to 4h.
  461. self.hidden_size = hidden_size
  462. if inner_hidden_size is None:
  463. inner_hidden_size = 4 * hidden_size
  464. self.inner_hidden_size = inner_hidden_size
  465. self.dense_h_to_4h = skip_init(
  466. torch.nn.Linear,
  467. self.hidden_size,
  468. self.inner_hidden_size,
  469. bias=bias,
  470. dtype=params_dtype,
  471. )
  472. # Project back to h.
  473. self.dense_4h_to_h = skip_init(
  474. torch.nn.Linear,
  475. self.inner_hidden_size,
  476. self.hidden_size,
  477. bias=bias,
  478. dtype=params_dtype,
  479. )
  480. def forward(self, hidden_states):
  481. """
  482. hidden_states: [seq_len, batch, hidden_size]
  483. """
  484. # [seq_len, batch, inner_hidden_size]
  485. intermediate_parallel = self.dense_h_to_4h(hidden_states)
  486. intermediate_parallel = self.activation_func(intermediate_parallel)
  487. output = self.dense_4h_to_h(intermediate_parallel)
  488. return output
  489. class GLMBlock(torch.nn.Module):
  490. def __init__(self,
  491. hidden_size,
  492. num_attention_heads,
  493. layernorm_epsilon,
  494. layer_id,
  495. inner_hidden_size=None,
  496. hidden_size_per_attention_head=None,
  497. layernorm=LayerNorm,
  498. use_bias=True,
  499. params_dtype=torch.float,
  500. num_layers=28,
  501. position_encoding_2d=True):
  502. super(GLMBlock, self).__init__()
  503. # Set output layer initialization if not provided.
  504. self.layer_id = layer_id
  505. # Layernorm on the input data.
  506. self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
  507. self.position_encoding_2d = position_encoding_2d
  508. # Self attention.
  509. self.attention = SelfAttention(
  510. hidden_size,
  511. num_attention_heads,
  512. layer_id,
  513. hidden_size_per_attention_head=hidden_size_per_attention_head,
  514. bias=use_bias,
  515. params_dtype=params_dtype,
  516. position_encoding_2d=self.position_encoding_2d)
  517. # Layernorm on the input data.
  518. self.post_attention_layernorm = layernorm(
  519. hidden_size, eps=layernorm_epsilon)
  520. self.num_layers = num_layers
  521. # GLU
  522. self.mlp = GLU(
  523. hidden_size,
  524. inner_hidden_size=inner_hidden_size,
  525. bias=use_bias,
  526. layer_id=layer_id,
  527. params_dtype=params_dtype,
  528. )
  529. def forward(
  530. self,
  531. hidden_states: torch.Tensor,
  532. position_ids,
  533. attention_mask: torch.Tensor,
  534. layer_id,
  535. layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  536. use_cache: bool = False,
  537. output_attentions: bool = False,
  538. ):
  539. """
  540. hidden_states: [seq_len, batch, hidden_size]
  541. attention_mask: [(1, 1), seq_len, seq_len]
  542. """
  543. # Layer norm at the beginning of the transformer layer.
  544. # [seq_len, batch, hidden_size]
  545. attention_input = self.input_layernorm(hidden_states)
  546. # Self attention.
  547. attention_outputs = self.attention(
  548. attention_input,
  549. position_ids,
  550. attention_mask=attention_mask,
  551. layer_id=layer_id,
  552. layer_past=layer_past,
  553. use_cache=use_cache,
  554. output_attentions=output_attentions)
  555. attention_output = attention_outputs[0]
  556. outputs = attention_outputs[1:]
  557. # Residual connection.
  558. alpha = (2 * self.num_layers)**0.5
  559. hidden_states = attention_input * alpha + attention_output
  560. mlp_input = self.post_attention_layernorm(hidden_states)
  561. # MLP.
  562. mlp_output = self.mlp(mlp_input)
  563. # Second residual connection.
  564. output = mlp_input * alpha + mlp_output
  565. if use_cache:
  566. outputs = (output, ) + outputs
  567. else:
  568. outputs = (output, ) + outputs[1:]
  569. return outputs # hidden_states, present, attentions
  570. class ChatGLMPreTrainedModel(TorchModel, PreTrainedModel):
  571. """
  572. An abstract class to handle weights initialization and
  573. a simple interface for downloading and loading pretrained models.
  574. """
  575. is_parallelizable = False
  576. supports_gradient_checkpointing = True
  577. config_class = ChatGLMConfig
  578. base_model_prefix = 'transformer'
  579. _no_split_modules = ['GLMBlock']
  580. def __init__(self, config, **kwargs):
  581. super().__init__(config.name_or_path, **kwargs)
  582. super(Model, self).__init__(config)
  583. def _init_weights(self, module: nn.Module):
  584. """Initialize the weights."""
  585. return
  586. def get_masks(self, input_ids, device):
  587. batch_size, seq_length = input_ids.shape
  588. context_lengths = [
  589. seq.tolist().index(self.config.bos_token_id) for seq in input_ids
  590. ]
  591. attention_mask = torch.ones((batch_size, seq_length, seq_length),
  592. device=device)
  593. attention_mask.tril_()
  594. for i, context_length in enumerate(context_lengths):
  595. attention_mask[i, :, :context_length] = 1
  596. attention_mask.unsqueeze_(1)
  597. attention_mask = (attention_mask < 0.5).bool()
  598. return attention_mask
  599. def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
  600. batch_size, seq_length = input_ids.shape
  601. context_lengths = [
  602. seq.tolist().index(self.config.bos_token_id) for seq in input_ids
  603. ]
  604. if self.position_encoding_2d:
  605. position_ids = torch.arange(
  606. seq_length, dtype=torch.long,
  607. device=device).unsqueeze(0).repeat(batch_size, 1)
  608. for i, context_length in enumerate(context_lengths):
  609. position_ids[i, context_length:] = mask_positions[i]
  610. block_position_ids = [
  611. torch.cat((
  612. torch.zeros( # noqa
  613. context_length,
  614. dtype=torch.long,
  615. device=device), # noqa
  616. torch.arange( # noqa
  617. seq_length - context_length, # noqa
  618. dtype=torch.long, # noqa
  619. device=device) + 1)) # noqa
  620. for context_length in context_lengths
  621. ]
  622. block_position_ids = torch.stack(block_position_ids, dim=0)
  623. position_ids = torch.stack((position_ids, block_position_ids),
  624. dim=1)
  625. else:
  626. position_ids = torch.arange(
  627. seq_length, dtype=torch.long,
  628. device=device).unsqueeze(0).repeat(batch_size, 1)
  629. if not gmask:
  630. for i, context_length in enumerate(context_lengths):
  631. position_ids[context_length:] = mask_positions[i]
  632. return position_ids
  633. def _set_gradient_checkpointing(self, module, value=False):
  634. if isinstance(module, ChatGLMModel):
  635. module.gradient_checkpointing = value
  636. @classmethod
  637. def _instantiate(cls, **kwargs):
  638. """Instantiate the model.
  639. Args:
  640. kwargs: Input args.
  641. model_dir: The model dir used to load the checkpoint and the label information.
  642. Returns:
  643. The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
  644. """
  645. model_dir = kwargs.pop('model_dir', None)
  646. kwargs.pop('cfg', None)
  647. model = super(Model, cls).from_pretrained(
  648. pretrained_model_name_or_path=model_dir, **kwargs)
  649. model.model_dir = model_dir
  650. return model
  651. CHATGLM_6B_START_DOCSTRING = r"""
  652. This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
  653. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
  654. usage and behavior.
  655. Parameters:
  656. config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model.
  657. Initializing with a config file does not load the weights associated with the model, only the configuration.
  658. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  659. """
  660. CHATGLM_6B_INPUTS_DOCSTRING = r"""
  661. Args:
  662. input_ids (`torch.LongTensor` of shape `({0})`):
  663. Indices of input sequence tokens in the vocabulary.
  664. Indices can be obtained using [`ChatGLM6BTokenizer`].
  665. See [`PreTrainedTokenizer.encode`] and
  666. [`PreTrainedTokenizer.__call__`] for details.
  667. [What are input IDs?](../glossary#input-ids)
  668. attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
  669. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  670. - 1 for tokens that are **not masked**,
  671. - 0 for tokens that are **masked**.
  672. [What are attention masks?](../glossary#attention-mask)
  673. token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
  674. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
  675. - 0 corresponds to a *sentence A* token,
  676. - 1 corresponds to a *sentence B* token.
  677. [What are token type IDs?](../glossary#token-type-ids)
  678. position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
  679. Indices of positions of each input sequence tokens in the position embeddings.
  680. Selected in the range `[0, config.max_position_embeddings - 1]`.
  681. [What are position IDs?](../glossary#position-ids)
  682. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  683. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  684. - 1 indicates the head is **not masked**,
  685. - 0 indicates the head is **masked**.
  686. inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
  687. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  688. This is useful if you want more control over how to convert *input_ids* indices into associated vectors
  689. than the model's internal embedding lookup matrix.
  690. output_attentions (`bool`, *optional*):
  691. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  692. tensors for more detail.
  693. output_hidden_states (`bool`, *optional*):
  694. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  695. more detail.
  696. return_dict (`bool`, *optional*):
  697. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  698. """
  699. @add_start_docstrings(
  700. 'The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.',
  701. CHATGLM_6B_START_DOCSTRING,
  702. )
  703. class ChatGLMModel(ChatGLMPreTrainedModel):
  704. """
  705. The model can behave as an encoder (with only self-attention) as well
  706. as a decoder, in which case a layer of cross-attention is added between
  707. the self-attention layers, following the architecture described in [Attention is
  708. all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
  709. Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  710. To behave as an decoder the model needs to be initialized with the
  711. `is_decoder` argument of the configuration set to `True`.
  712. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder`
  713. argument and `add_cross_attention` set to `True`; an
  714. `encoder_hidden_states` is then expected as an input to the forward pass.
  715. """
  716. def __init__(self, config: ChatGLMConfig):
  717. super().__init__(config)
  718. # recording parameters
  719. self.max_sequence_length = config.max_sequence_length
  720. self.hidden_size = config.hidden_size
  721. self.params_dtype = torch.half
  722. self.num_attention_heads = config.num_attention_heads
  723. self.vocab_size = config.vocab_size
  724. self.num_layers = config.num_layers
  725. self.layernorm_epsilon = config.layernorm_epsilon
  726. self.inner_hidden_size = config.inner_hidden_size
  727. self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
  728. self.position_encoding_2d = config.position_encoding_2d
  729. self.pre_seq_len = config.pre_seq_len
  730. self.prefix_projection = config.prefix_projection
  731. self.word_embeddings = skip_init(
  732. torch.nn.Embedding,
  733. num_embeddings=self.vocab_size,
  734. embedding_dim=self.hidden_size,
  735. dtype=self.params_dtype)
  736. self.gradient_checkpointing = False
  737. def get_layer(layer_id):
  738. return GLMBlock(
  739. self.hidden_size,
  740. self.num_attention_heads,
  741. self.layernorm_epsilon,
  742. layer_id,
  743. inner_hidden_size=self.inner_hidden_size,
  744. hidden_size_per_attention_head=self.
  745. hidden_size_per_attention_head,
  746. layernorm=LayerNorm,
  747. use_bias=True,
  748. params_dtype=self.params_dtype,
  749. position_encoding_2d=self.position_encoding_2d,
  750. )
  751. self.layers = torch.nn.ModuleList(
  752. [get_layer(layer_id) for layer_id in range(self.num_layers)])
  753. # Final layer norm before output.
  754. self.final_layernorm = LayerNorm(
  755. self.hidden_size, eps=self.layernorm_epsilon)
  756. if self.pre_seq_len is not None:
  757. for param in self.parameters():
  758. param.requires_grad = False
  759. self.prefix_tokens = torch.arange(self.pre_seq_len).long()
  760. self.prefix_encoder = PrefixEncoder(config)
  761. self.dropout = torch.nn.Dropout(0.1)
  762. # total_params = sum(p.numel() for p in self.parameters())
  763. # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
  764. # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params))
  765. def get_input_embeddings(self):
  766. return self.word_embeddings
  767. def set_input_embeddings(self, new_embeddings: torch.Tensor):
  768. self.word_embeddings = new_embeddings
  769. def get_prompt(self, batch_size, device, dtype=torch.half):
  770. prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size,
  771. -1).to(device)
  772. past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
  773. past_key_values = past_key_values.view(
  774. batch_size, self.pre_seq_len, self.num_layers * 2,
  775. self.num_attention_heads,
  776. self.hidden_size // self.num_attention_heads)
  777. # seq_len, b, nh, hidden_size
  778. past_key_values = self.dropout(past_key_values)
  779. past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
  780. # past_key_values = [(v[0], v[1]) for v in past_key_values]
  781. return past_key_values
  782. @add_start_docstrings_to_model_forward(
  783. CHATGLM_6B_INPUTS_DOCSTRING.format('batch_size, sequence_length'))
  784. @add_code_sample_docstrings(
  785. checkpoint=_CHECKPOINT_FOR_DOC,
  786. output_type=BaseModelOutputWithPastAndCrossAttentions,
  787. config_class=_CONFIG_FOR_DOC,
  788. )
  789. def forward(
  790. self,
  791. input_ids: Optional[torch.LongTensor] = None,
  792. position_ids: Optional[torch.LongTensor] = None,
  793. attention_mask: Optional[torch.Tensor] = None,
  794. past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor],
  795. ...]] = None,
  796. inputs_embeds: Optional[torch.LongTensor] = None,
  797. use_cache: Optional[bool] = None,
  798. output_attentions: Optional[bool] = None,
  799. output_hidden_states: Optional[bool] = None,
  800. return_dict: Optional[bool] = None,
  801. ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
  802. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  803. output_hidden_states = (
  804. output_hidden_states if output_hidden_states is not None else
  805. self.config.output_hidden_states)
  806. use_cache = use_cache if use_cache is not None else self.config.use_cache
  807. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  808. if self.gradient_checkpointing and self.training:
  809. if use_cache:
  810. # logger.warning_once(
  811. # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  812. # )
  813. use_cache = False
  814. if input_ids is not None and inputs_embeds is not None:
  815. raise ValueError(
  816. 'You cannot specify both input_ids and inputs_embeds at the same time'
  817. )
  818. elif input_ids is not None:
  819. batch_size, seq_length = input_ids.shape[:2]
  820. elif inputs_embeds is not None:
  821. batch_size, seq_length, _ = inputs_embeds.shape[:2]
  822. else:
  823. raise ValueError(
  824. 'You have to specify either input_ids or inputs_embeds')
  825. if inputs_embeds is None:
  826. inputs_embeds = self.word_embeddings(input_ids)
  827. if past_key_values is None:
  828. if self.pre_seq_len is not None:
  829. past_key_values = self.get_prompt(
  830. batch_size=input_ids.shape[0],
  831. device=input_ids.device,
  832. dtype=inputs_embeds.dtype)
  833. else:
  834. past_key_values = tuple([None] * len(self.layers))
  835. if attention_mask is None:
  836. attention_mask = self.get_masks(
  837. input_ids, device=input_ids.device)
  838. if position_ids is None:
  839. MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
  840. mask_token = gMASK if gMASK in input_ids else MASK
  841. use_gmask = True if gMASK in input_ids else False
  842. mask_positions = [
  843. seq.tolist().index(mask_token) for seq in input_ids
  844. ]
  845. position_ids = self.get_position_ids(
  846. input_ids,
  847. mask_positions=mask_positions,
  848. device=input_ids.device,
  849. gmask=use_gmask)
  850. if self.pre_seq_len is not None and attention_mask is not None:
  851. prefix_attention_mask = torch.ones(
  852. batch_size, 1, input_ids.size(-1),
  853. self.pre_seq_len).to(attention_mask.device)
  854. prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
  855. attention_mask = torch.cat((prefix_attention_mask, attention_mask),
  856. dim=3)
  857. # [seq_len, batch, hidden_size]
  858. hidden_states = inputs_embeds.transpose(0, 1)
  859. presents = () if use_cache else None
  860. all_self_attentions = () if output_attentions else None
  861. all_hidden_states = () if output_hidden_states else None
  862. if attention_mask is None:
  863. attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
  864. else:
  865. attention_mask = attention_mask.to(input_ids.device)
  866. for i, layer in enumerate(self.layers):
  867. if output_hidden_states:
  868. all_hidden_states = all_hidden_states + (hidden_states, )
  869. layer_past = past_key_values[i]
  870. if self.gradient_checkpointing and self.training:
  871. layer_ret = torch.utils.checkpoint.checkpoint(
  872. layer, hidden_states, position_ids, attention_mask,
  873. torch.tensor(i), layer_past, use_cache, output_attentions)
  874. else:
  875. layer_ret = layer(
  876. hidden_states,
  877. position_ids=position_ids,
  878. attention_mask=attention_mask,
  879. layer_id=torch.tensor(i),
  880. layer_past=layer_past,
  881. use_cache=use_cache,
  882. output_attentions=output_attentions)
  883. hidden_states = layer_ret[0]
  884. if use_cache:
  885. presents = presents + (layer_ret[1], )
  886. if output_attentions:
  887. all_self_attentions = all_self_attentions + (
  888. layer_ret[2 if use_cache else 1], )
  889. # Final layer norm.
  890. hidden_states = self.final_layernorm(hidden_states)
  891. if output_hidden_states:
  892. all_hidden_states = all_hidden_states + (hidden_states, )
  893. if not return_dict:
  894. return tuple(v for v in [
  895. hidden_states, presents, all_hidden_states, all_self_attentions
  896. ] if v is not None)
  897. return BaseModelOutputWithPast(
  898. last_hidden_state=hidden_states,
  899. past_key_values=presents,
  900. hidden_states=all_hidden_states,
  901. attentions=all_self_attentions,
  902. )
  903. @MODELS.register_module(Tasks.chat, module_name=Models.chatglm_6b)
  904. class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
  905. def __init__(self, config: ChatGLMConfig):
  906. super().__init__(config)
  907. # self.hidden_size = config.hidden_size
  908. # self.params_dtype = torch.half
  909. # self.vocab_size = config.vocab_size
  910. self.max_sequence_length = config.max_sequence_length
  911. self.position_encoding_2d = config.position_encoding_2d
  912. self.transformer = ChatGLMModel(config)
  913. self.lm_head = skip_init(
  914. nn.Linear,
  915. config.hidden_size,
  916. config.vocab_size,
  917. bias=False,
  918. dtype=torch.half)
  919. self.config = config
  920. self.quantized = False
  921. if self.config.quantization_bit:
  922. self.quantize(self.config.quantization_bit, empty_init=True)
  923. # loading tokenizer
  924. self.tokenizer = ChatGLMTokenizer.from_pretrained(config.name_or_path)
  925. def get_output_embeddings(self):
  926. return self.lm_head
  927. def set_output_embeddings(self, new_embeddings):
  928. self.lm_head = new_embeddings
  929. def _update_model_kwargs_for_generation(
  930. self,
  931. outputs: ModelOutput,
  932. model_kwargs: Dict[str, Any],
  933. is_encoder_decoder: bool = False,
  934. standardize_cache_format: bool = False,
  935. ) -> Dict[str, Any]:
  936. # update past_key_values
  937. model_kwargs['past_key_values'] = self._extract_past_from_model_output(
  938. outputs, standardize_cache_format=standardize_cache_format)
  939. # update attention mask
  940. if 'attention_mask' in model_kwargs:
  941. attention_mask = model_kwargs['attention_mask']
  942. if attention_mask is not None and attention_mask.dtype == torch.bool:
  943. attention_mask = torch.cat([
  944. attention_mask,
  945. attention_mask.new_ones((*attention_mask.shape[:3], 1))
  946. ],
  947. dim=3) # noqa
  948. new_attention_mask = attention_mask[:, :, -1:].clone()
  949. new_attention_mask[..., -1] = False
  950. model_kwargs['attention_mask'] = torch.cat(
  951. [attention_mask, new_attention_mask], dim=2)
  952. # update position ids
  953. if 'position_ids' in model_kwargs:
  954. position_ids = model_kwargs['position_ids']
  955. new_position_id = position_ids[..., -1:].clone()
  956. new_position_id[:, 1, :] += 1
  957. model_kwargs['position_ids'] = torch.cat(
  958. [position_ids, new_position_id], dim=-1)
  959. return model_kwargs
  960. def prepare_inputs_for_generation(
  961. self,
  962. input_ids: torch.LongTensor,
  963. past: Optional[torch.Tensor] = None,
  964. past_key_values: Optional[torch.Tensor] = None,
  965. attention_mask: Optional[torch.Tensor] = None,
  966. position_ids: Optional[torch.Tensor] = None,
  967. **kwargs) -> dict:
  968. batch_size, seq_length = input_ids.shape
  969. MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
  970. mask_token = gMASK if gMASK in input_ids else MASK
  971. use_gmask = True if gMASK in input_ids else False
  972. seqs = input_ids.tolist()
  973. mask_positions = [seq.index(mask_token) for seq in seqs]
  974. # only last token for input_ids if past is not None
  975. if past is not None or past_key_values is not None:
  976. last_token = input_ids[:, -1].unsqueeze(-1)
  977. if attention_mask is not None and attention_mask.dtype == torch.bool:
  978. attention_mask = attention_mask[:, :, -1:]
  979. else:
  980. attention_mask = None
  981. if position_ids is not None:
  982. position_ids = position_ids[..., -1:]
  983. else:
  984. context_lengths = [
  985. seq.index(self.config.bos_token_id) for seq in seqs
  986. ]
  987. if self.position_encoding_2d:
  988. position_ids = torch.tensor(
  989. [[mask_position, seq_length - context_length]
  990. for mask_position, context_length in zip(
  991. mask_positions, context_lengths)],
  992. dtype=torch.long,
  993. device=input_ids.device).unsqueeze(-1)
  994. else:
  995. position_ids = torch.tensor(
  996. [mask_position for mask_position in mask_positions],
  997. dtype=torch.long,
  998. device=input_ids.device).unsqueeze(-1)
  999. if past is None:
  1000. past = past_key_values
  1001. return {
  1002. 'input_ids': last_token,
  1003. 'past_key_values': past,
  1004. 'position_ids': position_ids,
  1005. 'attention_mask': attention_mask
  1006. }
  1007. else:
  1008. if attention_mask is not None and attention_mask.dtype != torch.bool:
  1009. # logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
  1010. attention_mask = None
  1011. if attention_mask is None:
  1012. attention_mask = self.get_masks(
  1013. input_ids, device=input_ids.device)
  1014. if position_ids is None:
  1015. position_ids = self.get_position_ids(
  1016. input_ids,
  1017. device=input_ids.device,
  1018. mask_positions=mask_positions,
  1019. gmask=use_gmask)
  1020. return {
  1021. 'input_ids': input_ids,
  1022. 'past_key_values': past,
  1023. 'position_ids': position_ids,
  1024. 'attention_mask': attention_mask
  1025. }
  1026. def forward(
  1027. self,
  1028. input_ids: Optional[torch.Tensor] = None,
  1029. position_ids: Optional[torch.Tensor] = None,
  1030. attention_mask: Optional[torch.Tensor] = None,
  1031. past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
  1032. inputs_embeds: Optional[torch.Tensor] = None,
  1033. labels: Optional[torch.Tensor] = None,
  1034. use_cache: Optional[bool] = None,
  1035. output_attentions: Optional[bool] = None,
  1036. output_hidden_states: Optional[bool] = None,
  1037. return_dict: Optional[bool] = None,
  1038. ):
  1039. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1040. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1041. transformer_outputs = self.transformer(
  1042. input_ids=input_ids,
  1043. position_ids=position_ids,
  1044. attention_mask=attention_mask,
  1045. past_key_values=past_key_values,
  1046. inputs_embeds=inputs_embeds,
  1047. use_cache=use_cache,
  1048. output_attentions=output_attentions,
  1049. output_hidden_states=output_hidden_states,
  1050. return_dict=return_dict,
  1051. )
  1052. hidden_states = transformer_outputs[0]
  1053. lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous()
  1054. loss = None
  1055. if labels is not None:
  1056. lm_logits = lm_logits.to(torch.float32)
  1057. # Shift so that tokens < n predict n
  1058. shift_logits = lm_logits[..., :-1, :].contiguous()
  1059. shift_labels = labels[..., 1:].contiguous()
  1060. # Flatten the tokens
  1061. loss_fct = CrossEntropyLoss(ignore_index=-100)
  1062. shift_labels = shift_labels.to(shift_logits.device)
  1063. loss = loss_fct(
  1064. shift_logits.view(-1, shift_logits.size(-1)),
  1065. shift_labels.view(-1))
  1066. lm_logits = lm_logits.to(hidden_states.dtype)
  1067. loss = loss.to(hidden_states.dtype)
  1068. if not return_dict:
  1069. output = (lm_logits, ) + transformer_outputs[1:]
  1070. return ((loss, ) + output) if loss is not None else output
  1071. return CausalLMOutputWithPast(
  1072. loss=loss,
  1073. logits=lm_logits,
  1074. past_key_values=transformer_outputs.past_key_values,
  1075. hidden_states=transformer_outputs.hidden_states,
  1076. attentions=transformer_outputs.attentions,
  1077. )
  1078. @staticmethod
  1079. def _reorder_cache(
  1080. past: Tuple[Tuple[torch.Tensor, torch.Tensor],
  1081. ...], beam_idx: torch.LongTensor
  1082. ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
  1083. """
  1084. This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
  1085. [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
  1086. beam_idx at every generation step.
  1087. Output shares the same memory storage as `past`.
  1088. """
  1089. return tuple((
  1090. layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
  1091. layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
  1092. ) for layer_past in past)
  1093. def process_response(self, response):
  1094. response = response.strip()
  1095. response = response.replace('[[训练时间]]', '2023年')
  1096. punkts = [
  1097. [',', ','],
  1098. ['!', '!'],
  1099. [':', ':'],
  1100. [';', ';'],
  1101. ['\?', '?'], # noqa
  1102. ]
  1103. for item in punkts:
  1104. response = re.sub(r'([\u4e00-\u9fff])%s' % item[0],
  1105. r'\1%s' % item[1], response)
  1106. response = re.sub(r'%s([\u4e00-\u9fff])' % item[0],
  1107. r'%s\1' % item[1], response)
  1108. return response
  1109. @torch.no_grad()
  1110. def _chat(self,
  1111. tokenizer,
  1112. query: str,
  1113. history: List[Tuple[str, str]] = None,
  1114. max_length: int = 2048,
  1115. num_beams=1,
  1116. do_sample=True,
  1117. top_p=0.7,
  1118. temperature=0.95,
  1119. logits_processor=None,
  1120. **kwargs):
  1121. if history is None:
  1122. history = []
  1123. if logits_processor is None:
  1124. logits_processor = LogitsProcessorList()
  1125. logits_processor.append(InvalidScoreLogitsProcessor())
  1126. gen_kwargs = {
  1127. 'max_length': max_length,
  1128. 'num_beams': num_beams,
  1129. 'do_sample': do_sample,
  1130. 'top_p': top_p,
  1131. 'temperature': temperature,
  1132. 'logits_processor': logits_processor,
  1133. **kwargs
  1134. }
  1135. if not history:
  1136. prompt = query
  1137. else:
  1138. prompt = ''
  1139. for i, (old_query, response) in enumerate(history):
  1140. prompt += '[Round {}]\n问:{}\n答:{}\n'.format(
  1141. i, old_query, response)
  1142. prompt += '[Round {}]\n问:{}\n答:'.format(len(history), query)
  1143. inputs = tokenizer([prompt], return_tensors='pt')
  1144. inputs = inputs.to(self.device)
  1145. outputs = self.generate(**inputs, **gen_kwargs)
  1146. outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):]
  1147. response = tokenizer.decode(outputs)
  1148. response = self.process_response(response)
  1149. history = history + [(query, response)]
  1150. return response, history
  1151. @torch.no_grad()
  1152. def stream_chat(self,
  1153. tokenizer,
  1154. query: str,
  1155. history: List[Tuple[str, str]] = None,
  1156. max_length: int = 2048,
  1157. do_sample=True,
  1158. top_p=0.7,
  1159. temperature=0.95,
  1160. logits_processor=None,
  1161. **kwargs):
  1162. if history is None:
  1163. history = []
  1164. if logits_processor is None:
  1165. logits_processor = LogitsProcessorList()
  1166. logits_processor.append(InvalidScoreLogitsProcessor())
  1167. gen_kwargs = {
  1168. 'max_length': max_length,
  1169. 'do_sample': do_sample,
  1170. 'top_p': top_p,
  1171. 'temperature': temperature,
  1172. 'logits_processor': logits_processor,
  1173. **kwargs
  1174. }
  1175. if not history:
  1176. prompt = query
  1177. else:
  1178. prompt = ''
  1179. for i, (old_query, response) in enumerate(history):
  1180. prompt += '[Round {}]\n问:{}\n答:{}\n'.format(
  1181. i, old_query, response)
  1182. prompt += '[Round {}]\n问:{}\n答:'.format(len(history), query)
  1183. inputs = tokenizer([prompt], return_tensors='pt')
  1184. inputs = inputs.to(self.device)
  1185. for outputs in self.stream_generate(**inputs, **gen_kwargs):
  1186. outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):]
  1187. response = tokenizer.decode(outputs)
  1188. response = self.process_response(response)
  1189. new_history = history + [(query, response)]
  1190. yield response, new_history
  1191. @torch.no_grad()
  1192. def stream_generate(
  1193. self,
  1194. input_ids,
  1195. generation_config: Optional[GenerationConfig] = None,
  1196. logits_processor: Optional[LogitsProcessorList] = None,
  1197. stopping_criteria: Optional[StoppingCriteriaList] = None,
  1198. prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
  1199. List[int]]] = None,
  1200. **kwargs,
  1201. ):
  1202. _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[
  1203. -1] # noqa
  1204. if generation_config is None:
  1205. generation_config = self.generation_config
  1206. generation_config = copy.deepcopy(generation_config)
  1207. model_kwargs = generation_config.update(**kwargs)
  1208. _, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
  1209. if isinstance(eos_token_id, int):
  1210. eos_token_id = [eos_token_id]
  1211. has_default_max_length = kwargs.get(
  1212. 'max_length') is None and generation_config.max_length is not None
  1213. if has_default_max_length and generation_config.max_new_tokens is None:
  1214. warnings.warn(
  1215. f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
  1216. 'This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we'
  1217. ' recommend using `max_new_tokens` to control the maximum length of the generation.',
  1218. UserWarning,
  1219. )
  1220. elif generation_config.max_new_tokens is not None:
  1221. generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
  1222. if not has_default_max_length:
  1223. logger.warn(
  1224. f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(='
  1225. f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. '
  1226. 'Please refer to the documentation for more information. '
  1227. '(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)',
  1228. UserWarning,
  1229. )
  1230. if input_ids_seq_length >= generation_config.max_length:
  1231. input_ids_string = 'decoder_input_ids' if self.config.is_encoder_decoder else 'input_ids'
  1232. logger.warning(
  1233. f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to'
  1234. f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider'
  1235. ' increasing `max_new_tokens`.')
  1236. # 2. Set generation parameters if not already defined
  1237. logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList(
  1238. )
  1239. stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList(
  1240. )
  1241. logits_processor = self._get_logits_processor(
  1242. generation_config=generation_config,
  1243. input_ids_seq_length=input_ids_seq_length,
  1244. encoder_input_ids=input_ids,
  1245. prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
  1246. logits_processor=logits_processor,
  1247. )
  1248. stopping_criteria = self._get_stopping_criteria(
  1249. generation_config=generation_config,
  1250. stopping_criteria=stopping_criteria)
  1251. logits_warper = self._get_logits_warper(generation_config)
  1252. unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
  1253. scores = None
  1254. while True:
  1255. model_inputs = self.prepare_inputs_for_generation(
  1256. input_ids, **model_kwargs)
  1257. # forward pass to get next token
  1258. outputs = self(
  1259. **model_inputs,
  1260. return_dict=True,
  1261. output_attentions=False,
  1262. output_hidden_states=False,
  1263. )
  1264. next_token_logits = outputs.logits[:, -1, :]
  1265. # pre-process distribution
  1266. next_token_scores = logits_processor(input_ids, next_token_logits)
  1267. next_token_scores = logits_warper(input_ids, next_token_scores)
  1268. # sample
  1269. probs = nn.functional.softmax(next_token_scores, dim=-1)
  1270. if generation_config.do_sample:
  1271. next_tokens = torch.multinomial(
  1272. probs, num_samples=1).squeeze(1)
  1273. else:
  1274. next_tokens = torch.argmax(probs, dim=-1)
  1275. # update generated ids, model inputs, and length for next step
  1276. input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
  1277. model_kwargs = self._update_model_kwargs_for_generation(
  1278. outputs,
  1279. model_kwargs,
  1280. is_encoder_decoder=self.config.is_encoder_decoder)
  1281. unfinished_sequences = unfinished_sequences.mul(
  1282. (sum(next_tokens != i for i in eos_token_id)).long())
  1283. # stop when each sentence is finished, or if we exceed the maximum length
  1284. if unfinished_sequences.max() == 0 or stopping_criteria(
  1285. input_ids, scores):
  1286. break
  1287. yield input_ids
  1288. def quantize(self, bits: int, empty_init=False, **kwargs):
  1289. if bits == 0:
  1290. return
  1291. from .quantization import quantize
  1292. if self.quantized:
  1293. logger.info('Already quantized.')
  1294. return self
  1295. self.quantized = True
  1296. self.config.quantization_bit = bits
  1297. self.transformer = quantize(
  1298. self.transformer, bits, empty_init=empty_init, **kwargs)
  1299. return self
  1300. def chat(self, input: Dict) -> Dict:
  1301. text = input['text']
  1302. history = input['history']
  1303. # args
  1304. if 'max_length' in input:
  1305. max_length = input['max_length']
  1306. else:
  1307. max_length = 2048
  1308. if 'temperature' in input:
  1309. temperature = input['temperature']
  1310. else:
  1311. temperature = 0.95
  1312. if 'num_beams' in input:
  1313. num_beams = input['num_beams']
  1314. else:
  1315. num_beams = 1
  1316. if 'do_sample' in input:
  1317. do_sample = input['do_sample']
  1318. else:
  1319. do_sample = True
  1320. if type(history) == torch.Tensor:
  1321. history = history.tolist()
  1322. response, history = self._chat(
  1323. self.tokenizer,
  1324. text,
  1325. history,
  1326. max_length=max_length,
  1327. temperature=temperature,
  1328. num_beams=num_beams,
  1329. do_sample=do_sample)
  1330. logger.info('Generation finished.')
  1331. return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history}