| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572 |
- """ PyTorch ChatGLM model. """
- import copy
- import math
- import os
- import re
- import sys
- import warnings
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
- import torch
- import torch.nn.functional as F
- import torch.utils.checkpoint
- from torch import nn
- from torch.nn import CrossEntropyLoss, LayerNorm
- from torch.nn.utils import skip_init
- from transformers.generation.logits_process import LogitsProcessor
- from transformers.generation.utils import (GenerationConfig,
- LogitsProcessorList, ModelOutput,
- StoppingCriteriaList)
- from transformers.modeling_outputs import (
- BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions,
- CausalLMOutputWithPast)
- from transformers.modeling_utils import PreTrainedModel
- from transformers.utils import (add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward)
- from modelscope.metainfo import Models
- from modelscope.models import MODELS, Model, TorchModel
- from modelscope.outputs import OutputKeys
- from modelscope.utils import logger as logging
- from modelscope.utils.constant import Tasks
- from .configuration import ChatGLMConfig
- from .tokenization import ChatGLMTokenizer
# flags required to enable jit fusion kernels
# These private torch._C switches are applied at import time on every
# platform except macOS (sys.platform == 'darwin').
if sys.platform != 'darwin':
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)

# Module-level logger used by the helpers in this file.
logger = logging.get_logger()

# Names intended for generated documentation.
# NOTE(review): 'ChatGLM6BConfig' does not match the ChatGLMConfig class this
# file imports -- confirm which name the docs should show.
_CHECKPOINT_FOR_DOC = 'THUDM/ChatGLM-6B'
_CONFIG_FOR_DOC = 'ChatGLM6BConfig'

CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
    'THUDM/chatglm-6b',
    # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
]
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Logits processor that overwrites score tensors containing NaN or Inf.

    If any entry of ``scores`` is NaN or infinite, the tensor is zeroed in
    place and index 5 along the last dimension is set to ``5e4``. Otherwise
    the scores are returned untouched.
    """

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor) -> torch.FloatTensor:
        contains_invalid = torch.isnan(scores).any() or torch.isinf(
            scores).any()
        if not contains_invalid:
            return scores
        # In-place reset so the caller's tensor is the one that is repaired.
        scores.zero_()
        scores[..., 5] = 5e4
        return scores
def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model.

    Every variable stored in the TensorFlow checkpoint is read, its
    slash-separated name is walked through the attributes of ``model``, and
    the array is copied into the tensor the walk ends on.

    Args:
        model: Object whose attributes are traversed and overwritten.
        config: Not referenced by this function.
        tf_checkpoint_path: Path of the TensorFlow checkpoint to read.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        ImportError: When ``numpy`` or ``tensorflow`` cannot be imported; the
            error is logged and re-raised.
        AssertionError: When the target tensor and the loaded array differ in
            shape; both shapes are appended to the exception args.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see '
            'https://www.tensorflow.org/install/ for installation instructions.'
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f'Converting TensorFlow checkpoint from {tf_path}')
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f'Loading TF weight {name} with shape {shape}')
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(n in [
                'adam_v', 'adam_m', 'AdamWeightDecayOptimizer',
                'AdamWeightDecayOptimizer_1', 'global_step'
        ] for n in name):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        # Walk the name components from the model root.
        pointer = model
        for m_name in name:
            # Components such as "layer_3" split into a base name and an index.
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                scope_names = re.split(r'_(\d+)', m_name)
            else:
                scope_names = [m_name]
            # Map TensorFlow variable names onto attribute names.
            if scope_names[0] == 'kernel' or scope_names[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif scope_names[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] == 'squad':
                pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    # NOTE(review): this `continue` binds to the innermost
                    # loop (over name components), so `pointer` keeps its
                    # previous value and the variable is still assigned
                    # below -- confirm this is intended.
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        # m_name is the last name component visited by the loop above.
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            array = np.transpose(array)
        try:
            assert (
                pointer.shape == array.shape
            ), f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched'
        except AssertionError as e:
            # Attach both shapes to the exception before propagating it.
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f'Initialize PyTorch weight {name}')
        pointer.data = torch.from_numpy(array)
    return model
class PrefixEncoder(torch.nn.Module):
    """
    The torch.nn model to encode the prefix
    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    """

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        kv_width = config.num_layers * config.hidden_size * 2
        # With projection the table stays at hidden size and an MLP widens it;
        # without projection the table itself holds the full key/value width.
        if self.prefix_projection:
            table_width = config.hidden_size
        else:
            table_width = kv_width
        self.embedding = torch.nn.Embedding(config.pre_seq_len, table_width)
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.hidden_size, kv_width),
            )

    def forward(self, prefix: torch.Tensor):
        """Embed the prefix indices, applying the MLP when projection is on."""
        embedded = self.embedding(prefix)
        if not self.prefix_projection:
            return embedded
        return self.trans(embedded)
@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    # tanh-based approximation; 0.7978845608028654 is sqrt(2 / pi).
    inner = 0.7978845608028654 * x * (1.0 + 0.044715 * x * x)
    return 0.5 * x * (1.0 + torch.tanh(inner))
def gelu(x):
    """Apply the scripted tanh-approximation GELU (``gelu_impl``) to ``x``."""
    activated = gelu_impl(x)
    return activated
class RotaryEmbedding(torch.nn.Module):
    """Produce the cos/sin tables used for rotary position embeddings.

    The inverse frequencies are built once in the constructor (stored in half
    precision). ``forward`` turns them into ``cos``/``sin`` tensors of shape
    ``[seq_len, 1, dim]`` and, in the non-learnable mode, caches the tables so
    shorter or equal requests are served by slicing.

    Args:
        dim: Rotary dimension; ``dim // 2`` frequencies are generated.
        base: Base of the geometric frequency progression.
        precision: Only compared against ``torch.bfloat16`` to decide whether
            the tables are computed in float32 and cast to bfloat16.
        learnable: When true, ``inv_freq`` is a parameter and nothing is
            cached; otherwise it is a buffer.
    """

    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
        inv_freq = 1. / (base**(torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
        else:
            self.register_buffer('inv_freq', inv_freq)
        # The cache attributes exist in both modes: _apply() reads them
        # unconditionally, so a learnable instance must define them as well
        # or .to()/.half()/.float() would fail with AttributeError.
        self.max_seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Deliberately a no-op: nothing is restored from checkpoints for this
        # module, so inv_freq keeps the value computed in __init__.
        pass

    def forward(self, x, seq_dim=1, seq_len=None):
        """Return ``(cos, sin)`` tables covering ``seq_len`` positions.

        Args:
            x: Tensor supplying the device (and the default ``seq_len``).
            seq_dim: Dimension of ``x`` read when ``seq_len`` is omitted.
            seq_len: Number of positions required.
        """
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (
                seq_len > self.max_seq_len_cached):  # noqa
            # Learnable mode never records a cached length, so it recomputes
            # on every call.
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(
                seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()
            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn):
        # The cached tables are plain attributes (not buffers), so they have
        # to follow device/dtype conversions manually.
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn)
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half."""
    midpoint = x.shape[-1] // 2
    front = x[..., :midpoint]
    back = x[..., midpoint:]
    # dim=-1 triggers a bug in earlier torch versions
    return torch.cat((-back, front), dim=front.ndim - 1)
@torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
    """Rotate ``q`` and ``k`` using the table rows selected by ``position_id``."""
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    cos_at_pos = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2)
    sin_at_pos = F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
    q_rotated = (q * cos_at_pos) + (rotate_half(q) * sin_at_pos)
    k_rotated = (k * cos_at_pos) + (rotate_half(k) * sin_at_pos)
    return q_rotated, k_rotated
def attention_fn(
    self,
    query_layer,
    key_layer,
    value_layer,
    attention_mask,
    hidden_size_per_partition,
    layer_id,
    layer_past=None,
    scaling_attention_score=True,
    use_cache=False,
):
    """Compute scaled dot-product attention for one layer.

    Args:
        self: Object providing ``scale_mask_softmax`` (a callable or a falsy
            value selecting the plain softmax path).
        query_layer: Queries, ``[sq, b, np, hn]``.
        key_layer: Keys, ``[sk, b, np, hn]``.
        value_layer: Values, ``[sk, b, np, hn]``.
        attention_mask: Boolean mask; true entries are filled with -10000.0.
        hidden_size_per_partition: Size of the merged head dimension of the
            returned context.
        layer_id: Zero-based layer index; ``layer_id + 1`` is the scaling
            coefficient.
        layer_past: Optional ``(key, value)`` pair prepended along dim 0.
        scaling_attention_score: Whether queries are divided by
            ``sqrt(hn) * (layer_id + 1)`` before the matmul.
        use_cache: Whether the (concatenated) key/value pair is returned.

    Returns:
        ``(context_layer, present, attention_probs)`` with context of shape
        ``[sq, b, hp]``, ``present`` being the key/value tuple or ``None``,
        and probabilities viewed as ``[b * np, sq, sk]``.
    """
    if layer_past is not None:
        cached_key, cached_value = layer_past[0], layer_past[1]
        key_layer = torch.cat((cached_key, key_layer), dim=0)
        value_layer = torch.cat((cached_value, value_layer), dim=0)

    # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
    _, _, _, head_dim = key_layer.shape

    # Captured before the tensors are re-viewed below.
    present = (key_layer, value_layer) if use_cache else None

    layer_coeff = float(layer_id + 1)
    if scaling_attention_score:
        query_layer = query_layer / (math.sqrt(head_dim) * layer_coeff)

    # ===================================
    # Raw attention scores. [b, np, s, s]
    # ===================================
    batch, heads = query_layer.size(1), query_layer.size(2)
    sq, sk = query_layer.size(0), key_layer.size(0)
    flat_heads = batch * heads

    # [sq, b, np, hn] -> [sq, b * np, hn]
    query_layer = query_layer.view(sq, flat_heads, -1)
    # [sk, b, np, hn] -> [sk, b * np, hn]
    key_layer = key_layer.view(sk, flat_heads, -1)

    # beta=0.0 means the placeholder input contributes nothing to the result.
    placeholder = torch.zeros(
        1,
        1,
        1,
        dtype=query_layer.dtype,
        device=query_layer.device,
    )
    raw_scores = torch.baddbmm(
        placeholder,
        query_layer.transpose(0, 1),  # [b * np, sq, hn]
        key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
        beta=0.0,
        alpha=1.0,
    )

    # change view to [b, np, sq, sk]
    attention_scores = raw_scores.view(batch, heads, sq, sk)

    if self.scale_mask_softmax:
        self.scale_mask_softmax.scale = layer_coeff
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask.contiguous())
    else:
        mask_is_empty = (attention_mask == 0).all()
        if not mask_is_empty:
            # if auto-regressive, skip
            attention_scores.masked_fill_(attention_mask, -10000.0)
        score_dtype = attention_scores.dtype
        # Softmax runs in float32; the layer coefficient is re-applied here.
        rescaled = attention_scores.float() * layer_coeff
        attention_probs = F.softmax(rescaled, dim=-1).type(score_dtype)

    # =========================
    # Context layer. [sq, b, hp]
    # =========================
    v_batch, v_heads = value_layer.size(1), value_layer.size(2)
    v_dim = value_layer.size(3)

    # change view [sk, b * np, hn]
    value_layer = value_layer.view(
        value_layer.size(0), v_batch * v_heads, -1)
    # change view [b * np, sq, sk]
    attention_probs = attention_probs.view(v_batch * v_heads, sq, -1)

    # matmul: [b * np, sq, hn]
    context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
    # change view [b, np, sq, hn]
    context_layer = context_layer.view(v_batch, v_heads, sq, v_dim)
    # [b, np, sq, hn] --> [sq, b, np, hn]
    context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
    # [sq, b, np, hn] --> [sq, b, hp]
    merged_shape = context_layer.size()[:-2] + (hidden_size_per_partition, )
    context_layer = context_layer.view(*merged_shape)

    return context_layer, present, attention_probs
class SelfAttention(torch.nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    One linear layer produces the fused query/key/value projection. Rotary
    embeddings are applied to queries and keys (on two halves driven by two
    position streams when ``position_encoding_2d`` is set), then the
    module-level ``attention_fn`` computes the context, which ``dense``
    projects back to ``hidden_size``.
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 layer_id,
                 hidden_size_per_attention_head=None,
                 bias=True,
                 params_dtype=torch.float,
                 position_encoding_2d=True):
        super(SelfAttention, self).__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.hidden_size_per_partition = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_heads_per_partition = num_attention_heads
        self.position_encoding_2d = position_encoding_2d

        # With 2D positions each head is split in two halves that are rotated
        # separately, so the rotary table only spans half a head.
        if position_encoding_2d:
            rotary_dim = self.hidden_size // (self.num_attention_heads * 2)
        else:
            rotary_dim = self.hidden_size // self.num_attention_heads
        self.rotary_emb = RotaryEmbedding(
            rotary_dim,
            base=10000,
            precision=torch.half,
            learnable=False,
        )

        self.scale_mask_softmax = None

        if hidden_size_per_attention_head is None:
            hidden_size_per_attention_head = hidden_size // num_attention_heads
        self.hidden_size_per_attention_head = hidden_size_per_attention_head
        self.inner_hidden_size = (
            num_attention_heads * self.hidden_size_per_attention_head)

        linear_options = dict(bias=bias, dtype=params_dtype)
        # Strided linear layer.
        self.query_key_value = skip_init(torch.nn.Linear, hidden_size,
                                         3 * self.inner_hidden_size,
                                         **linear_options)
        self.dense = skip_init(torch.nn.Linear, self.inner_hidden_size,
                               hidden_size, **linear_options)

    @staticmethod
    def attention_mask_func(attention_scores, attention_mask):
        """Fill masked score entries with -10000.0 in place and return them."""
        attention_scores.masked_fill_(attention_mask, -10000.0)
        return attention_scores

    def split_tensor_along_last_dim(self,
                                    tensor,
                                    num_partitions,
                                    contiguous_split_chunks=False):
        """Split a tensor along its last dimension.
        Arguments:
            tensor: input tensor.
            num_partitions: number of partitions to split the tensor
            contiguous_split_chunks: If True, make each chunk contiguous
                                     in memory.
        """
        last_dim = tensor.dim() - 1
        chunk_width = tensor.size()[last_dim] // num_partitions
        pieces = torch.split(tensor, chunk_width, dim=last_dim)
        # Note: torch.split does not create contiguous tensors by default.
        if not contiguous_split_chunks:
            return pieces
        return tuple(piece.contiguous() for piece in pieces)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids,
        attention_mask: torch.Tensor,
        layer_id,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """
        hidden_states: [seq_len, batch, hidden_size]
        attention_mask: [(1, 1), seq_len, seq_len]
        """
        # [seq_len, batch, 3 * hidden_size]
        fused_qkv = self.query_key_value(hidden_states)

        # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads,
        #                                        3 * hidden_size_per_attention_head]
        per_head_shape = fused_qkv.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        fused_qkv = fused_qkv.view(*per_head_shape)

        # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
        query_layer, key_layer, value_layer = self.split_tensor_along_last_dim(
            fused_qkv, 3)

        if self.position_encoding_2d:
            q_first, q_second = query_layer.chunk(
                2, dim=(query_layer.ndim - 1))
            k_first, k_second = key_layer.chunk(2, dim=(key_layer.ndim - 1))
            cos, sin = self.rotary_emb(q_first, seq_len=position_ids.max() + 1)
            # Channel 0 holds the token positions, channel 1 the in-block ones.
            token_positions = position_ids[:, 0, :].transpose(0,
                                                              1).contiguous()
            block_positions = position_ids[:, 1, :].transpose(0,
                                                              1).contiguous()
            q_first, k_first = apply_rotary_pos_emb_index(
                q_first, k_first, cos, sin, token_positions)
            q_second, k_second = apply_rotary_pos_emb_index(
                q_second, k_second, cos, sin, block_positions)
            query_layer = torch.concat([q_first, q_second],
                                       dim=(q_first.ndim - 1))
            key_layer = torch.concat([k_first, k_second],
                                     dim=(k_first.ndim - 1))
        else:
            position_ids = position_ids.transpose(0, 1)
            cos, sin = self.rotary_emb(
                value_layer, seq_len=position_ids.max() + 1)
            # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
            query_layer, key_layer = apply_rotary_pos_emb_index(
                query_layer, key_layer, cos, sin, position_ids)

        # [seq_len, batch, hidden_size]
        context_layer, present, attention_probs = attention_fn(
            self=self,
            query_layer=query_layer,
            key_layer=key_layer,
            value_layer=value_layer,
            attention_mask=attention_mask,
            hidden_size_per_partition=self.hidden_size_per_partition,
            layer_id=layer_id,
            layer_past=layer_past,
            use_cache=use_cache)

        projected = self.dense(context_layer)

        # output, present, attention_probs
        if output_attentions:
            return (projected, present, attention_probs)
        return (projected, present)
class GEGLU(torch.nn.Module):
    """Gated unit: the first half of the last dim times GELU of the second."""

    def __init__(self):
        super().__init__()
        self.activation_fn = F.gelu

    def forward(self, x):
        """Split ``x`` in two along its last dimension and gate the halves."""
        # dim=-1 breaks in jit for pt<1.10
        last_axis = x.ndim - 1
        value, gate = x.chunk(2, dim=last_axis)
        return value * self.activation_fn(gate)
class GLU(torch.nn.Module):
    """Feed-forward block: linear up-projection, activation, linear down-projection.

    ``inner_hidden_size`` defaults to ``4 * hidden_size`` when not given.
    """

    def __init__(self,
                 hidden_size,
                 inner_hidden_size=None,
                 layer_id=None,
                 bias=True,
                 activation_func=gelu,
                 params_dtype=torch.float):
        super(GLU, self).__init__()
        self.layer_id = layer_id
        self.activation_func = activation_func

        self.hidden_size = hidden_size
        self.inner_hidden_size = (4 * hidden_size if inner_hidden_size is None
                                  else inner_hidden_size)

        linear_options = dict(bias=bias, dtype=params_dtype)
        # Project to 4h.
        self.dense_h_to_4h = skip_init(torch.nn.Linear, self.hidden_size,
                                       self.inner_hidden_size,
                                       **linear_options)
        # Project back to h.
        self.dense_4h_to_h = skip_init(torch.nn.Linear,
                                       self.inner_hidden_size,
                                       self.hidden_size, **linear_options)

    def forward(self, hidden_states):
        """
        hidden_states: [seq_len, batch, hidden_size]
        """
        # [seq_len, batch, inner_hidden_size]
        expanded = self.dense_h_to_4h(hidden_states)
        activated = self.activation_func(expanded)
        return self.dense_4h_to_h(activated)
class GLMBlock(torch.nn.Module):
    """One transformer layer: layer norm, self-attention, layer norm, MLP.

    Both residual additions scale the *normalised* input by
    ``sqrt(2 * num_layers)`` before adding the sub-layer output.
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 layernorm_epsilon,
                 layer_id,
                 inner_hidden_size=None,
                 hidden_size_per_attention_head=None,
                 layernorm=LayerNorm,
                 use_bias=True,
                 params_dtype=torch.float,
                 num_layers=28,
                 position_encoding_2d=True):
        super(GLMBlock, self).__init__()
        self.layer_id = layer_id
        self.position_encoding_2d = position_encoding_2d
        self.num_layers = num_layers

        # Sub-modules are created in the same order as before so parameter
        # registration order is unchanged.
        # Layernorm on the input data.
        self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

        # Self attention.
        self.attention = SelfAttention(
            hidden_size,
            num_attention_heads,
            layer_id,
            hidden_size_per_attention_head=hidden_size_per_attention_head,
            bias=use_bias,
            params_dtype=params_dtype,
            position_encoding_2d=self.position_encoding_2d)

        # Layernorm applied to the attention result before the MLP.
        self.post_attention_layernorm = layernorm(
            hidden_size, eps=layernorm_epsilon)

        # GLU
        self.mlp = GLU(
            hidden_size,
            inner_hidden_size=inner_hidden_size,
            bias=use_bias,
            layer_id=layer_id,
            params_dtype=params_dtype,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids,
        attention_mask: torch.Tensor,
        layer_id,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """
        hidden_states: [seq_len, batch, hidden_size]
        attention_mask: [(1, 1), seq_len, seq_len]
        """
        residual_scale = (2 * self.num_layers)**0.5

        # Layer norm at the beginning of the transformer layer.
        # [seq_len, batch, hidden_size]
        normed_input = self.input_layernorm(hidden_states)

        # Self attention.
        attention_result = self.attention(
            normed_input,
            position_ids,
            attention_mask=attention_mask,
            layer_id=layer_id,
            layer_past=layer_past,
            use_cache=use_cache,
            output_attentions=output_attentions)
        attention_output = attention_result[0]
        extras = attention_result[1:]

        # Residual connection.
        hidden_states = normed_input * residual_scale + attention_output

        # MLP.
        normed_hidden = self.post_attention_layernorm(hidden_states)
        mlp_output = self.mlp(normed_hidden)

        # Second residual connection.
        block_output = normed_hidden * residual_scale + mlp_output

        # Without caching the leading `present` entry is dropped.
        if not use_cache:
            extras = extras[1:]

        # hidden_states, present, attentions
        return (block_output, ) + extras
class ChatGLMPreTrainedModel(TorchModel, PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = ChatGLMConfig
    base_model_prefix = 'transformer'
    _no_split_modules = ['GLMBlock']

    def __init__(self, config, **kwargs):
        # Two explicit initialisations: the first follows the regular MRO with
        # the model path, the second resumes the MRO after `Model` and
        # receives the config object.
        super().__init__(config.name_or_path, **kwargs)
        super(Model, self).__init__(config)

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        # Intentionally empty: no weight initialisation is performed here.
        return

    def get_masks(self, input_ids, device):
        """Build the boolean attention mask for a batch.

        Args:
            input_ids: Tensor of shape ``(batch, seq_len)``; every row must
                contain ``config.bos_token_id``.
            device: Device on which the mask is created.

        Returns:
            Boolean tensor of shape ``(batch, 1, seq_len, seq_len)`` where
            ``True`` marks positions that must not be attended to.

        Raises:
            ValueError: If a row does not contain the BOS token.
        """
        batch_size, seq_length = input_ids.shape
        # Index of the first BOS token marks the end of each row's context.
        context_lengths = [
            seq.tolist().index(self.config.bos_token_id) for seq in input_ids
        ]
        attention_mask = torch.ones((batch_size, seq_length, seq_length),
                                    device=device)
        attention_mask.tril_()
        # Context columns are visible from every position (bidirectional).
        for i, context_length in enumerate(context_lengths):
            attention_mask[i, :, :context_length] = 1
        attention_mask.unsqueeze_(1)
        attention_mask = (attention_mask < 0.5).bool()

        return attention_mask

    def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
        """Build position ids for a batch.

        Args:
            input_ids: Tensor of shape ``(batch, seq_len)``; every row must
                contain ``config.bos_token_id``.
            mask_positions: Per-row position value assigned to everything
                from the BOS token onwards.
            device: Device on which the ids are created.
            gmask: Only consulted when ``position_encoding_2d`` is false; if
                true, the plain ``arange`` positions are kept.

        Returns:
            ``(batch, 2, seq_len)`` tensor (positions and block positions)
            when ``self.position_encoding_2d`` is set, otherwise
            ``(batch, seq_len)``.

        Raises:
            ValueError: If a row does not contain the BOS token.
        """
        batch_size, seq_length = input_ids.shape
        context_lengths = [
            seq.tolist().index(self.config.bos_token_id) for seq in input_ids
        ]
        position_ids = torch.arange(
            seq_length, dtype=torch.long,
            device=device).unsqueeze(0).repeat(batch_size, 1)
        if self.position_encoding_2d:
            for i, context_length in enumerate(context_lengths):
                position_ids[i, context_length:] = mask_positions[i]
            # Block positions: 0 over the context, then 1, 2, ... afterwards.
            block_position_ids = [
                torch.cat((
                    torch.zeros(
                        context_length, dtype=torch.long, device=device),
                    torch.arange(
                        seq_length - context_length,
                        dtype=torch.long,
                        device=device) + 1,
                )) for context_length in context_lengths
            ]
            block_position_ids = torch.stack(block_position_ids, dim=0)
            position_ids = torch.stack((position_ids, block_position_ids),
                                       dim=1)
        elif not gmask:
            for i, context_length in enumerate(context_lengths):
                # Index row i explicitly: slicing only the first axis would
                # overwrite whole rows of other samples in the batch.
                position_ids[i, context_length:] = mask_positions[i]

        return position_ids

    def _set_gradient_checkpointing(self, module, value=False):
        # Only the ChatGLMModel backbone carries the checkpointing flag.
        if isinstance(module, ChatGLMModel):
            module.gradient_checkpointing = value

    @classmethod
    def _instantiate(cls, **kwargs):
        """Instantiate the model.

        Args:
            kwargs: Input args.
                    model_dir: The model dir used to load the checkpoint and the label information.

        Returns:
            The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
        """
        model_dir = kwargs.pop('model_dir', None)
        # 'cfg' is dropped and not forwarded to from_pretrained.
        kwargs.pop('cfg', None)
        model = super(Model, cls).from_pretrained(
            pretrained_model_name_or_path=model_dir, **kwargs)
        model.model_dir = model_dir
        return model
# Shared class-level documentation passed to `add_start_docstrings` below.
# NOTE(review): the text names `ChatGLM6BConfig`, while this file imports and
# uses `ChatGLMConfig` -- confirm which name should be documented.
CHATGLM_6B_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
usage and behavior.
Parameters:
config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Documentation template for forward() arguments; `{0}` is a format
# placeholder for the input shape.
# NOTE(review): the text names `ChatGLM6BTokenizer`, while this file imports
# `ChatGLMTokenizer` -- confirm which name should be documented.
CHATGLM_6B_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`ChatGLM6BTokenizer`].
See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary#token-type-ids)
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range `[0, config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert *input_ids* indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
- @add_start_docstrings(
- 'The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.',
- CHATGLM_6B_START_DOCSTRING,
- )
- class ChatGLMModel(ChatGLMPreTrainedModel):
- """
- The model can behave as an encoder (with only self-attention) as well
- as a decoder, in which case a layer of cross-attention is added between
- the self-attention layers, following the architecture described in [Attention is
- all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
- Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
- To behave as an decoder the model needs to be initialized with the
- `is_decoder` argument of the configuration set to `True`.
- To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder`
- argument and `add_cross_attention` set to `True`; an
- `encoder_hidden_states` is then expected as an input to the forward pass.
- """
def __init__(self, config: ChatGLMConfig):
    """Assemble the embedding table, the GLM block stack and the final norm.

    When ``config.pre_seq_len`` is set, every parameter registered up to
    that point is frozen and a ``PrefixEncoder`` plus a dropout layer are
    attached afterwards (so the prefix encoder itself is not frozen here).
    """
    super().__init__(config)

    # Hyper-parameters copied from the config.
    self.max_sequence_length = config.max_sequence_length
    self.hidden_size = config.hidden_size
    self.params_dtype = torch.half
    self.num_attention_heads = config.num_attention_heads
    self.vocab_size = config.vocab_size
    self.num_layers = config.num_layers
    self.layernorm_epsilon = config.layernorm_epsilon
    self.inner_hidden_size = config.inner_hidden_size
    self.hidden_size_per_attention_head = (
        self.hidden_size // self.num_attention_heads)
    self.position_encoding_2d = config.position_encoding_2d
    self.pre_seq_len = config.pre_seq_len
    self.prefix_projection = config.prefix_projection

    self.word_embeddings = skip_init(
        torch.nn.Embedding,
        num_embeddings=self.vocab_size,
        embedding_dim=self.hidden_size,
        dtype=self.params_dtype)
    self.gradient_checkpointing = False

    # One GLMBlock per layer; each block is told its own index.
    blocks = [
        GLMBlock(
            self.hidden_size,
            self.num_attention_heads,
            self.layernorm_epsilon,
            index,
            inner_hidden_size=self.inner_hidden_size,
            hidden_size_per_attention_head=self.
            hidden_size_per_attention_head,
            layernorm=LayerNorm,
            use_bias=True,
            params_dtype=self.params_dtype,
            position_encoding_2d=self.position_encoding_2d,
        ) for index in range(self.num_layers)
    ]
    self.layers = torch.nn.ModuleList(blocks)

    # Final layer norm before output.
    self.final_layernorm = LayerNorm(
        self.hidden_size, eps=self.layernorm_epsilon)

    if self.pre_seq_len is None:
        return

    # Freeze everything created so far, then add the prefix modules.
    for frozen in self.parameters():
        frozen.requires_grad = False
    self.prefix_tokens = torch.arange(self.pre_seq_len).long()
    self.prefix_encoder = PrefixEncoder(config)
    self.dropout = torch.nn.Dropout(0.1)
def get_input_embeddings(self):
    """Return the module stored in ``self.word_embeddings``."""
    embedding_table = self.word_embeddings
    return embedding_table
def set_input_embeddings(self, new_embeddings: torch.Tensor):
    """Replace ``self.word_embeddings`` with ``new_embeddings``."""
    setattr(self, 'word_embeddings', new_embeddings)
def get_prompt(self, batch_size, device, dtype=torch.half):
    """Encode the prefix tokens into per-layer past key/value chunks.

    Args:
        batch_size: number of rows the prefix token ids are expanded to.
        device: device the prefix token ids are moved to before encoding.
        dtype: type the prefix encoder output is converted to.

    Returns:
        The result of ``split(2)`` on the permuted tensor: chunks of size
        two along the leading ``num_layers * 2`` axis.
    """
    token_ids = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
    token_ids = token_ids.to(device)

    prefix = self.prefix_encoder(token_ids).type(dtype)
    head_dim = self.hidden_size // self.num_attention_heads
    prefix = prefix.view(
        batch_size,
        self.pre_seq_len,
        self.num_layers * 2,
        self.num_attention_heads,
        head_dim,
    )
    prefix = self.dropout(prefix)

    # (batch, prefix_len, layers * 2, heads, head_dim)
    #   -> (layers * 2, prefix_len, batch, heads, head_dim)
    stacked = prefix.permute([2, 1, 0, 3, 4])
    return stacked.split(2)
@add_start_docstrings_to_model_forward(
    CHATGLM_6B_INPUTS_DOCSTRING.format('batch_size, sequence_length'))
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=BaseModelOutputWithPastAndCrossAttentions,
    config_class=_CONFIG_FOR_DOC,
)
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor],
                                    ...]] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
    # Embeds the input (unless embeddings are passed in), runs it through
    # every layer in ``self.layers`` and applies ``self.final_layernorm``.
    #
    # Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
    # When only ``inputs_embeds`` is given and there is no cache yet,
    # ``attention_mask`` and ``position_ids`` must be supplied as well,
    # because their defaults are derived from the token ids.
    #
    # Returns a ``BaseModelOutputWithPast`` when ``return_dict`` is true,
    # otherwise a tuple of the non-None values among (hidden_states,
    # presents, all_hidden_states, all_self_attentions).
    #
    # Raises ValueError for an invalid combination of inputs.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if use_cache is None:
        use_cache = self.config.use_cache
    if return_dict is None:
        return_dict = self.config.use_return_dict

    if self.gradient_checkpointing and self.training and use_cache:
        # The cache is switched off while training with gradient
        # checkpointing.
        use_cache = False

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            'You cannot specify both input_ids and inputs_embeds at the same time'
        )
    if input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
        device = input_ids.device
    elif inputs_embeds is not None:
        # Two-element slice, two targets (the original unpacked three
        # names here and therefore always raised).
        batch_size, seq_length = inputs_embeds.shape[:2]
        device = inputs_embeds.device
    else:
        raise ValueError(
            'You have to specify either input_ids or inputs_embeds')

    if inputs_embeds is None:
        inputs_embeds = self.word_embeddings(input_ids)

    if past_key_values is None:
        if self.pre_seq_len is not None:
            past_key_values = self.get_prompt(
                batch_size=batch_size,
                device=device,
                dtype=inputs_embeds.dtype)
        else:
            past_key_values = tuple([None] * len(self.layers))

        if input_ids is None and (attention_mask is None
                                  or position_ids is None):
            raise ValueError(
                'attention_mask and position_ids are required when only '
                'inputs_embeds is passed')

        if attention_mask is None:
            attention_mask = self.get_masks(input_ids, device=device)

        if position_ids is None:
            MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
            use_gmask = bool(gMASK in input_ids)
            mask_token = gMASK if use_gmask else MASK
            # list.index raises ValueError if a row has no mask token.
            mask_positions = [
                seq.tolist().index(mask_token) for seq in input_ids
            ]
            position_ids = self.get_position_ids(
                input_ids,
                mask_positions=mask_positions,
                device=device,
                gmask=use_gmask)

    if self.pre_seq_len is not None and attention_mask is not None:
        # Prepend an all-False block of width ``pre_seq_len`` to the mask.
        prefix_attention_mask = torch.ones(
            batch_size, 1, seq_length,
            self.pre_seq_len).to(attention_mask.device)
        prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
        attention_mask = torch.cat((prefix_attention_mask, attention_mask),
                                   dim=3)

    # [seq_len, batch, hidden_size]
    hidden_states = inputs_embeds.transpose(0, 1)

    presents = () if use_cache else None
    all_self_attentions = () if output_attentions else None
    all_hidden_states = () if output_hidden_states else None

    # Reached with ``attention_mask is None`` when a cache was passed in
    # without a mask; an all-False (1, 1) mask is used in that case.
    if attention_mask is None:
        attention_mask = torch.zeros(1, 1, device=device).bool()
    else:
        attention_mask = attention_mask.to(device)

    for i, layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )
        layer_past = past_key_values[i]

        if self.gradient_checkpointing and self.training:
            layer_ret = torch.utils.checkpoint.checkpoint(
                layer, hidden_states, position_ids, attention_mask,
                torch.tensor(i), layer_past, use_cache, output_attentions)
        else:
            layer_ret = layer(
                hidden_states,
                position_ids=position_ids,
                attention_mask=attention_mask,
                layer_id=torch.tensor(i),
                layer_past=layer_past,
                use_cache=use_cache,
                output_attentions=output_attentions)

        hidden_states = layer_ret[0]
        if use_cache:
            presents = presents + (layer_ret[1], )
        if output_attentions:
            # The attention entry follows the cache entry when one exists.
            all_self_attentions = all_self_attentions + (
                layer_ret[2 if use_cache else 1], )

    # Final layer norm.
    hidden_states = self.final_layernorm(hidden_states)
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states, )

    if not return_dict:
        return tuple(v for v in [
            hidden_states, presents, all_hidden_states, all_self_attentions
        ] if v is not None)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
    )
- @MODELS.register_module(Tasks.chat, module_name=Models.chatglm_6b)
- class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
def __init__(self, config: ChatGLMConfig):
    """Create the transformer, the LM head and the tokenizer.

    If ``config.quantization_bit`` is truthy, ``self.quantize`` is called
    with ``empty_init=True`` before the tokenizer is loaded from
    ``config.name_or_path``.
    """
    super().__init__(config)

    self.max_sequence_length = config.max_sequence_length
    self.position_encoding_2d = config.position_encoding_2d

    self.transformer = ChatGLMModel(config)
    head_options = dict(bias=False, dtype=torch.half)
    self.lm_head = skip_init(nn.Linear, config.hidden_size,
                             config.vocab_size, **head_options)

    self.config = config
    self.quantized = False
    if self.config.quantization_bit:
        self.quantize(self.config.quantization_bit, empty_init=True)

    # loading tokenizer
    self.tokenizer = ChatGLMTokenizer.from_pretrained(config.name_or_path)
def get_output_embeddings(self):
    """Return the module stored in ``self.lm_head``."""
    output_head = self.lm_head
    return output_head
def set_output_embeddings(self, new_embeddings):
    """Replace ``self.lm_head`` with ``new_embeddings``."""
    setattr(self, 'lm_head', new_embeddings)
def _update_model_kwargs_for_generation(
    self,
    outputs: ModelOutput,
    model_kwargs: Dict[str, Any],
    is_encoder_decoder: bool = False,
    standardize_cache_format: bool = False,
) -> Dict[str, Any]:
    """Advance ``model_kwargs`` by one generated token.

    Stores the cache extracted from ``outputs``, grows a boolean
    ``attention_mask`` by one column and one row, and appends one entry
    to ``position_ids``. The same ``model_kwargs`` dict is returned.
    """
    model_kwargs['past_key_values'] = self._extract_past_from_model_output(
        outputs, standardize_cache_format=standardize_cache_format)

    # Only boolean masks are extended; any other mask is left untouched.
    mask = model_kwargs.get('attention_mask')
    if mask is not None and mask.dtype == torch.bool:
        extra_column = mask.new_ones((*mask.shape[:3], 1))
        widened = torch.cat([mask, extra_column], dim=3)  # noqa
        # The new row copies the last one, with its final entry cleared.
        extra_row = widened[:, :, -1:].clone()
        extra_row[..., -1] = False
        model_kwargs['attention_mask'] = torch.cat([widened, extra_row],
                                                   dim=2)

    if 'position_ids' in model_kwargs:
        current = model_kwargs['position_ids']
        appended = current[..., -1:].clone()
        # Index 1 along dim 1 is incremented; index 0 is repeated as is.
        appended[:, 1, :] += 1
        model_kwargs['position_ids'] = torch.cat([current, appended],
                                                 dim=-1)

    return model_kwargs
def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs) -> dict:
    """Build the keyword arguments for one forward call during generation.

    With a cache (``past`` or ``past_key_values``) only the last token,
    the last attention-mask row (boolean masks only, otherwise ``None``)
    and the last position id are forwarded. Without a cache the full
    ``input_ids`` are forwarded and missing masks / position ids come from
    ``self.get_masks`` / ``self.get_position_ids``.

    The mask-token (and BOS) lookup is performed only when position ids
    actually have to be derived, so decoding steps that already carry
    ``position_ids`` no longer convert and scan the whole sequence.

    Returns:
        dict with keys ``input_ids``, ``past_key_values``,
        ``position_ids`` and ``attention_mask``.

    Raises:
        ValueError: if position ids must be derived and a row lacks the
            mask token (or the BOS token with 2D position encoding), or
            if ``input_ids`` is not two-dimensional.
    """
    _, seq_length = input_ids.shape

    def find_mask_positions():
        # gMASK takes precedence over MASK when it occurs anywhere.
        mask_id = self.config.mask_token_id
        gmask_id = self.config.gmask_token_id
        use_gmask = bool(gmask_id in input_ids)
        target = gmask_id if use_gmask else mask_id
        seqs = input_ids.tolist()
        return seqs, [seq.index(target) for seq in seqs], use_gmask

    if past is not None or past_key_values is not None:
        # only last token for input_ids if past is not None
        last_token = input_ids[:, -1].unsqueeze(-1)
        if attention_mask is not None and attention_mask.dtype == torch.bool:
            attention_mask = attention_mask[:, :, -1:]
        else:
            attention_mask = None

        if position_ids is not None:
            position_ids = position_ids[..., -1:]
        else:
            seqs, mask_positions, _ = find_mask_positions()
            if self.position_encoding_2d:
                bos = self.config.bos_token_id
                rows = [[mask_position, seq_length - seq.index(bos)]
                        for mask_position, seq in zip(mask_positions, seqs)]
            else:
                rows = mask_positions
            position_ids = torch.tensor(
                rows, dtype=torch.long,
                device=input_ids.device).unsqueeze(-1)

        if past is None:
            past = past_key_values
        return {
            'input_ids': last_token,
            'past_key_values': past,
            'position_ids': position_ids,
            'attention_mask': attention_mask
        }

    # No cache yet: non-boolean masks are discarded and rebuilt.
    if attention_mask is not None and attention_mask.dtype != torch.bool:
        attention_mask = None
    if attention_mask is None:
        attention_mask = self.get_masks(input_ids, device=input_ids.device)
    if position_ids is None:
        _, mask_positions, use_gmask = find_mask_positions()
        position_ids = self.get_position_ids(
            input_ids,
            device=input_ids.device,
            mask_positions=mask_positions,
            gmask=use_gmask)
    return {
        'input_ids': input_ids,
        'past_key_values': past,
        'position_ids': position_ids,
        'attention_mask': attention_mask
    }
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    """Run the transformer, project to vocabulary logits, optionally score.

    The first transformer output is passed through ``self.lm_head`` and
    its first two axes are swapped. When ``labels`` is given, a shifted
    cross-entropy loss (``ignore_index=-100``) is computed in float32 and
    cast back to the hidden-state dtype.

    Returns:
        ``CausalLMOutputWithPast`` when ``return_dict`` is true; otherwise
        a tuple ``(logits, *rest)`` with ``loss`` prepended if computed.
    """
    if use_cache is None:
        use_cache = self.config.use_cache
    if return_dict is None:
        return_dict = self.config.use_return_dict

    transformer_outputs = self.transformer(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    hidden_states = transformer_outputs[0]

    projected = self.lm_head(hidden_states)
    lm_logits = projected.permute(1, 0, 2).contiguous()

    loss = None
    if labels is not None:
        lm_logits = lm_logits.to(torch.float32)

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_labels = shift_labels.to(shift_logits.device)

        # Flatten the tokens
        criterion = CrossEntropyLoss(ignore_index=-100)
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        loss = criterion(flat_logits, shift_labels.view(-1))

        lm_logits = lm_logits.to(hidden_states.dtype)
        loss = loss.to(hidden_states.dtype)

    if return_dict:
        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    output = (lm_logits, ) + transformer_outputs[1:]
    if loss is None:
        return output
    return (loss, ) + output
@staticmethod
def _reorder_cache(
    past: Tuple[Tuple[torch.Tensor, torch.Tensor],
                ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
    """
    Re-order the `past_key_values` cache for [`~PreTrainedModel.beam_search`] or
    [`~PreTrainedModel.beam_sample`], so that each cached entry lines up with the
    beam selected in `beam_idx` at the current generation step.

    Entries 0 and 1 of every layer are gathered along dimension 1 with
    `beam_idx` (moved to the entry's own device first).
    """

    def pick_beams(states):
        return states.index_select(1, beam_idx.to(states.device))

    reordered = []
    for layer_past in past:
        reordered.append((pick_beams(layer_past[0]),
                          pick_beams(layer_past[1])))
    return tuple(reordered)
def process_response(self, response):
    """Clean up a decoded model response.

    Strips surrounding whitespace, replaces the ``[[训练时间]]`` placeholder
    with ``2023年``, and converts ASCII ``, ! : ; ?`` into their full-width
    forms wherever the mark directly follows or precedes a character in
    the range U+4E00..U+9FFF. Punctuation between non-CJK characters is
    left alone.

    The full-width targets are written as explicit code points: the
    previous table mapped every mark to itself, which made the whole
    substitution loop a no-op.
    """
    response = response.strip()
    response = response.replace('[[训练时间]]', '2023年')

    cjk = r'[\u4e00-\u9fff]'
    punct_map = (
        (',', '\uff0c'),
        ('!', '\uff01'),
        (':', '\uff1a'),
        (';', '\uff1b'),
        ('?', '\uff1f'),
    )
    for ascii_mark, wide_mark in punct_map:
        # re.escape keeps '?' literal without relying on a '\?' string.
        escaped = re.escape(ascii_mark)
        response = re.sub('(%s)%s' % (cjk, escaped), r'\1' + wide_mark,
                          response)
        response = re.sub('%s(%s)' % (escaped, cjk), wide_mark + r'\1',
                          response)
    return response
@torch.no_grad()
def _chat(self,
          tokenizer,
          query: str,
          history: List[Tuple[str, str]] = None,
          max_length: int = 2048,
          num_beams=1,
          do_sample=True,
          top_p=0.7,
          temperature=0.95,
          logits_processor=None,
          **kwargs):
    """Answer ``query`` in one shot and return ``(response, history)``.

    Args:
        tokenizer: callable used to encode the prompt and decode the reply.
        query: the new user message.
        history: earlier ``(query, response)`` pairs; ``None`` means none.
        max_length, num_beams, do_sample, top_p, temperature: forwarded to
            ``self.generate``; entries in ``kwargs`` override them.
        logits_processor: optional processors. They are copied into a new
            ``LogitsProcessorList`` so the caller's list is not modified
            (previously every call appended another
            ``InvalidScoreLogitsProcessor`` to it).

    Returns:
        The post-processed response and a new history list with the
        ``(query, response)`` pair appended; the input list is not mutated.
    """
    if history is None:
        history = []

    if logits_processor is None:
        processors = LogitsProcessorList()
    else:
        processors = LogitsProcessorList(logits_processor)
    processors.append(InvalidScoreLogitsProcessor())

    gen_kwargs = {
        'max_length': max_length,
        'num_beams': num_beams,
        'do_sample': do_sample,
        'top_p': top_p,
        'temperature': temperature,
        'logits_processor': processors,
        **kwargs
    }

    if not history:
        prompt = query
    else:
        # One block per finished round, then the open round for ``query``.
        rounds = [
            '[Round {}]\n问:{}\n答:{}\n'.format(i, old_query, response)
            for i, (old_query, response) in enumerate(history)
        ]
        rounds.append('[Round {}]\n问:{}\n答:'.format(len(history), query))
        prompt = ''.join(rounds)

    inputs = tokenizer([prompt], return_tensors='pt')
    inputs = inputs.to(self.device)
    outputs = self.generate(**inputs, **gen_kwargs)
    # Drop the prompt tokens; keep only what was generated.
    outputs = outputs.tolist()[0][len(inputs['input_ids'][0]):]
    response = tokenizer.decode(outputs)
    response = self.process_response(response)
    history = history + [(query, response)]
    return response, history
@torch.no_grad()
def stream_chat(self,
                tokenizer,
                query: str,
                history: List[Tuple[str, str]] = None,
                max_length: int = 2048,
                do_sample=True,
                top_p=0.7,
                temperature=0.95,
                logits_processor=None,
                **kwargs):
    """Answer ``query`` incrementally, yielding ``(response, history)``.

    Each item yielded by ``self.stream_generate`` is decoded (prompt
    tokens removed), post-processed, and yielded together with a fresh
    history list that ends with ``(query, response)``.

    Args:
        tokenizer: callable used to encode the prompt and decode the reply.
        query: the new user message.
        history: earlier ``(query, response)`` pairs; ``None`` means none.
        max_length, do_sample, top_p, temperature: forwarded to
            ``self.stream_generate``; entries in ``kwargs`` override them.
        logits_processor: optional processors. They are copied into a new
            ``LogitsProcessorList`` so the caller's list is not modified
            (previously every call appended another
            ``InvalidScoreLogitsProcessor`` to it).
    """
    if history is None:
        history = []

    if logits_processor is None:
        processors = LogitsProcessorList()
    else:
        processors = LogitsProcessorList(logits_processor)
    processors.append(InvalidScoreLogitsProcessor())

    gen_kwargs = {
        'max_length': max_length,
        'do_sample': do_sample,
        'top_p': top_p,
        'temperature': temperature,
        'logits_processor': processors,
        **kwargs
    }

    if not history:
        prompt = query
    else:
        # One block per finished round, then the open round for ``query``.
        rounds = [
            '[Round {}]\n问:{}\n答:{}\n'.format(i, old_query, response)
            for i, (old_query, response) in enumerate(history)
        ]
        rounds.append('[Round {}]\n问:{}\n答:'.format(len(history), query))
        prompt = ''.join(rounds)

    inputs = tokenizer([prompt], return_tensors='pt')
    inputs = inputs.to(self.device)
    prompt_length = len(inputs['input_ids'][0])
    for outputs in self.stream_generate(**inputs, **gen_kwargs):
        outputs = outputs.tolist()[0][prompt_length:]
        response = tokenizer.decode(outputs)
        response = self.process_response(response)
        new_history = history + [(query, response)]
        yield response, new_history
@torch.no_grad()
def stream_generate(
    self,
    input_ids,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
                                                List[int]]] = None,
    **kwargs,
):
    """Generate one token per step, yielding the growing ``input_ids``.

    Args:
        input_ids: prompt token ids; the last axis is the sequence.
        generation_config: defaults to ``self.generation_config``. A deep
            copy is updated with ``kwargs``; whatever ``update`` returns is
            forwarded to ``prepare_inputs_for_generation``.
        logits_processor, stopping_criteria, prefix_allowed_tokens_fn:
            passed to the ``_get_logits_processor`` /
            ``_get_stopping_criteria`` helpers.

    Yields:
        ``input_ids`` with every token produced so far appended. The step
        on which all sequences hit an EOS id, or on which the stopping
        criteria fire, ends the loop without a further yield.

    A sequence counts as finished once its newest token equals ANY
    configured EOS id; a missing ``eos_token_id`` means only the stopping
    criteria can end generation.
    """
    input_ids_seq_length = input_ids.shape[-1]

    if generation_config is None:
        generation_config = self.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)

    eos_token_id = generation_config.eos_token_id
    if eos_token_id is None:
        eos_token_id = []
    elif isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]

    has_default_max_length = kwargs.get(
        'max_length') is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
            'This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we'
            ' recommend using `max_new_tokens` to control the maximum length of the generation.',
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
        if not has_default_max_length:
            # Logger calls take %-format args, not a warning category;
            # the stray ``UserWarning`` argument broke message formatting.
            logger.warning(
                f'Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(='
                f'{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. '
                'Please refer to the documentation for more information. '
                '(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)'
            )

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = 'decoder_input_ids' if self.config.is_encoder_decoder else 'input_ids'
        logger.warning(
            f'Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to'
            f' {generation_config.max_length}. This can lead to unexpected behavior. You should consider'
            ' increasing `max_new_tokens`.')

    # 2. Set generation parameters if not already defined
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    if stopping_criteria is None:
        stopping_criteria = StoppingCriteriaList()

    logits_processor = self._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )
    stopping_criteria = self._get_stopping_criteria(
        generation_config=generation_config,
        stopping_criteria=stopping_criteria)
    logits_warper = self._get_logits_warper(generation_config)

    # 1 while a sequence is still generating, 0 once it has emitted EOS.
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        model_inputs = self.prepare_inputs_for_generation(
            input_ids, **model_kwargs)
        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(
                probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder)

        # Summing the ``!=`` tests stayed non-zero whenever two or more
        # EOS ids were configured, so EOS never ended generation.
        hit_eos = torch.zeros_like(next_tokens, dtype=torch.bool)
        for eos in eos_token_id:
            hit_eos |= next_tokens == eos
        unfinished_sequences = unfinished_sequences.mul((~hit_eos).long())

        # stop when each sentence is finished, or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(
                input_ids, scores):
            break
        yield input_ids
def quantize(self, bits: int, empty_init=False, **kwargs):
    """Quantize ``self.transformer`` to ``bits`` bits.

    Returns ``None`` without doing anything when ``bits == 0``. Otherwise
    returns ``self``: immediately (after logging) if the model is already
    marked as quantized, or after replacing ``self.transformer`` with the
    result of ``quantization.quantize`` and recording ``bits`` on the
    config.
    """
    if bits == 0:
        return

    # Imported on demand, and before the already-quantized check.
    from .quantization import quantize as quantize_transformer

    if self.quantized:
        logger.info('Already quantized.')
        return self

    self.quantized = True
    self.config.quantization_bit = bits
    self.transformer = quantize_transformer(
        self.transformer, bits, empty_init=empty_init, **kwargs)
    return self
def chat(self, input: Dict) -> Dict:
    """Run one chat turn described by ``input`` and return the reply.

    Args:
        input: mapping with a required ``'text'`` entry and optional
            ``'history'``, ``'max_length'`` (default 2048),
            ``'temperature'`` (default 0.95), ``'num_beams'`` (default 1)
            and ``'do_sample'`` (default True) entries. A tensor history
            is converted with ``tolist()``; a missing history is passed
            on as ``None``, which ``_chat`` treats as an empty
            conversation.

    Returns:
        dict mapping ``OutputKeys.RESPONSE`` to the reply and
        ``OutputKeys.HISTORY`` to the updated history.

    Raises:
        KeyError: if ``input`` has no ``'text'`` entry.
    """
    text = input['text']

    history = input.get('history')
    if isinstance(history, torch.Tensor):
        history = history.tolist()

    response, history = self._chat(
        self.tokenizer,
        text,
        history,
        max_length=input.get('max_length', 2048),
        temperature=input.get('temperature', 0.95),
        num_beams=input.get('num_beams', 1),
        do_sample=input.get('do_sample', True))
    logger.info('Generation finished.')
    return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history}
|