| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030 |
- # Copyright (c) 2022 Zhipu.AI
- import math
- import torch
- import torch.nn.functional as F
def fast_gelu(x):
    """Apply MindSpore's "fast GELU" activation element-wise.

    The expression is algebraically ``x * sigmoid(1.702 * x)``; it is arranged
    so that every ``exp`` call receives a non-positive argument.
    """
    magnitude = torch.abs(x)
    # exp(0.851 * (x - |x|)) is 1 for x >= 0 and exp(1.702 * x) for x < 0.
    negative_side_scale = torch.exp(0.851 * (x - magnitude))
    denominator = 1 + torch.exp(-1.702 * magnitude)
    return x / denominator * negative_side_scale
class MLP(torch.nn.Module):
    """Position-wise feed-forward block.

    The input is projected from ``hidden_size`` to ``4 * hidden_size``, passed
    through ``fast_gelu`` and projected back to ``hidden_size``. No dropout is
    applied by this module.
    """

    def __init__(
        self,
        hidden_size,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        expanded_size = 4 * self.hidden_size

        # Up-projection: h -> 4h.
        self.dense_h_to_4h = torch.nn.Linear(self.hidden_size, expanded_size)
        self.activation_func = fast_gelu
        # Down-projection: 4h -> h.
        self.dense_4h_to_h = torch.nn.Linear(expanded_size, self.hidden_size)

    def forward(self, hidden_states):
        """Return the feed-forward transform of ``hidden_states`` ([..., h])."""
        expanded = self.dense_h_to_4h(hidden_states)
        activated = self.activation_func(expanded)
        return self.dense_4h_to_h(activated)
class SelfAttention(torch.nn.Module):
    """Multi-head self-attention for sequence-first activations.

    ``forward`` treats its input as ``[sq, b, h]`` (sequence, batch, hidden)
    and returns a tensor of the same shape.

    Arguments:
        hidden_size: model width ``h``; must be divisible by
            ``num_attention_heads``.
        num_attention_heads: number of attention heads ``np``.
        layer_number: index of the owning layer; stored clamped to at least 1
            and not otherwise read by this class.
        fp16: stored on the module; not read by ``forward``.
        attention_softmax_in_fp32: when true, the softmax is evaluated in
            float32 and the probabilities are cast back to the dtype of the
            value tensor before being multiplied with it.
    """

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        layer_number,
        fp16=True,
        attention_softmax_in_fp32=True,
    ):
        super(SelfAttention, self).__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.fp16 = fp16
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.layer_number = max(1, layer_number)

        assert self.hidden_size % self.num_attention_heads == 0
        self.hidden_size_per_attention_head = int(
            self.hidden_size // self.num_attention_heads)

        self.query = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.key = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.value = torch.nn.Linear(self.hidden_size, self.hidden_size)

        # Scores are divided by sqrt(head size).
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.softmax = torch.nn.Softmax(dim=-1)

        self.dense = torch.nn.Linear(self.hidden_size, self.hidden_size)

    def _split_heads(self, projected):
        """Reshape ``[..., h]`` into ``[..., np, hn]``."""
        head_shape = projected.size()[:-1] + (
            self.num_attention_heads, self.hidden_size_per_attention_head)
        return projected.view(*head_shape)

    def forward(
        self,
        hidden_states,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        """Attend over ``hidden_states`` plus any cached keys/values.

        Arguments:
            hidden_states: input of shape ``[sq, b, h]``.
            attention_mask: mask combined with the ``[b, np, sq, sk]`` scores;
                ``mask * 10000`` is subtracted from the scores, so non-zero
                entries suppress a position.
            layer_past: optional ``(key, value)`` pair that is prepended to
                the freshly computed keys/values along the sequence dimension.
            get_key_value: when true, the ``(key, value)`` pair is returned as
                well and the mask is sliced to the current key length.
            prompt_length: accepted but not used by this method.
            context_length: when given, query rows from this index onward are
                fully masked (on a copy of the mask).

        Returns:
            The ``[sq, b, h]`` output, or ``[output, (key, value)]`` when
            ``get_key_value`` is true.
        """
        # =====================
        # Query, Key, and Value  [sq, b, np, hn]
        # =====================
        query_layer = self._split_heads(self.query(hidden_states))
        key_layer = self._split_heads(self.key(hidden_states))
        value_layer = self._split_heads(self.value(hidden_states))

        # ==================================
        # Adjust key and value for inference
        # ==================================
        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat(
                (past_key.type_as(key_layer), key_layer), dim=0)
            value_layer = torch.cat(
                (past_value.type_as(value_layer), value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, sq, sk]
        # ===================================
        batch_size = query_layer.size(1)
        num_heads = query_layer.size(2)
        query_length = query_layer.size(0)
        key_length = key_layer.size(0)
        flat_heads = batch_size * num_heads

        # [s, b, np, hn] -> [s, b * np, hn]
        query_layer = query_layer.contiguous().view(
            query_length, flat_heads, -1)
        key_layer = key_layer.contiguous().view(key_length, flat_heads, -1)

        # [b * np, sq, hn] x [b * np, hn, sk] -> [b * np, sq, sk]
        matmul_result = torch.matmul(
            query_layer.transpose(0, 1),
            key_layer.permute(1, 2, 0)) / self.norm_factor
        attention_scores = matmul_result.view(
            batch_size, num_heads, query_length, key_length)

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================
        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    # Only the mask row of the newest position is needed.
                    attention_mask = attention_mask[
                        ..., key_length - 1, :key_length].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ..., :key_length, :key_length]

        if context_length is not None:
            # Clone first so the caller's mask is not modified in place.
            attention_mask = torch.clone(attention_mask)
            attention_mask[:, :, context_length:, :] = True

        attention_scores = attention_scores - attention_mask * 10000.0
        if self.attention_softmax_in_fp32:
            # Cast back to the value dtype (not unconditionally to fp16) so
            # the bmm below also works for float32 / bfloat16 modules.
            attention_probs = self.softmax(
                attention_scores.float()).to(value_layer.dtype)
        else:
            attention_probs = self.softmax(attention_scores)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================
        head_size = value_layer.size(3)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0), flat_heads, -1)
        # [b, np, sq, sk] -> [b * np, sq, sk]
        attention_probs = attention_probs.view(flat_heads, query_length, -1)

        # [b * np, sq, sk] x [b * np, sk, hn] -> [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs,
                                  value_layer.transpose(0, 1))
        context_layer = context_layer.view(
            batch_size, num_heads, query_length, head_size)
        # [b, np, sq, hn] -> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
        # [sq, b, np, hn] -> [sq, b, h]
        merged_shape = context_layer.size()[:-2] + (self.hidden_size, )
        context_layer = context_layer.view(*merged_shape)

        # =================
        # Output. [sq, b, h]
        # =================
        output = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output
class TopQuerySelfAttention(torch.nn.Module):
    """Multi-head attention whose queries come from a separate input.

    Keys and values are projected from ``hidden_states`` while queries are
    projected from ``query_hidden_state``; both are treated as sequence-first
    ``[s, b, h]`` tensors and the output has the shape of the query input.

    Arguments:
        hidden_size: model width ``h``; must be divisible by
            ``num_attention_heads``.
        num_attention_heads: number of attention heads ``np``.
        layer_number: index of the owning layer; stored clamped to at least 1
            and not otherwise read by this class.
        fp16: stored on the module; not read by ``forward``.
        attention_softmax_in_fp32: when true, the softmax is evaluated in
            float32 and the probabilities are cast back to the dtype of the
            value tensor before being multiplied with it.
    """

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        layer_number,
        fp16=True,
        attention_softmax_in_fp32=True,
    ):
        super(TopQuerySelfAttention, self).__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.fp16 = fp16
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.layer_number = max(1, layer_number)

        assert self.hidden_size % self.num_attention_heads == 0
        self.hidden_size_per_attention_head = int(
            self.hidden_size // self.num_attention_heads)

        self.query = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.key = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.value = torch.nn.Linear(self.hidden_size, self.hidden_size)

        # Scores are divided by sqrt(head size).
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.softmax = torch.nn.Softmax(dim=-1)

        self.dense = torch.nn.Linear(self.hidden_size, self.hidden_size)

    def _split_heads(self, projected):
        """Reshape ``[..., h]`` into ``[..., np, hn]``."""
        head_shape = projected.size()[:-1] + (
            self.num_attention_heads, self.hidden_size_per_attention_head)
        return projected.view(*head_shape)

    def forward(
        self,
        hidden_states,
        query_hidden_state,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        """Attend from ``query_hidden_state`` over ``hidden_states``.

        Arguments:
            hidden_states: source of keys and values, shape ``[sk, b, h]``.
            query_hidden_state: source of queries, shape ``[sq, b, h]``.
            attention_mask: mask combined with the ``[b, np, sq, sk]`` scores;
                ``mask * 10000`` is subtracted from the scores, so non-zero
                entries suppress a position.
            layer_past: optional ``(key, value)`` pair that is prepended to
                the freshly computed keys/values along the sequence dimension.
            get_key_value: when true, the ``(key, value)`` pair is returned as
                well and the mask is sliced to the current key length.
            prompt_length: accepted but not used by this method.
            context_length: when given, query rows from this index onward are
                fully masked (on a copy of the mask).

        Returns:
            The ``[sq, b, h]`` output, or ``[output, (key, value)]`` when
            ``get_key_value`` is true.
        """
        # Queries, keys and values as [s, b, np, hn].
        query_layer = self._split_heads(self.query(query_hidden_state))
        key_layer = self._split_heads(self.key(hidden_states))
        value_layer = self._split_heads(self.value(hidden_states))

        # ==================================
        # Adjust key and value for inference
        # ==================================
        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat(
                (past_key.type_as(key_layer), key_layer), dim=0)
            value_layer = torch.cat(
                (past_value.type_as(value_layer), value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, sq, sk]
        # ===================================
        batch_size = query_layer.size(1)
        num_heads = query_layer.size(2)
        query_length = query_layer.size(0)
        key_length = key_layer.size(0)
        flat_heads = batch_size * num_heads

        # [s, b, np, hn] -> [s, b * np, hn]
        query_layer = query_layer.contiguous().view(
            query_length, flat_heads, -1)
        key_layer = key_layer.contiguous().view(key_length, flat_heads, -1)

        # [b * np, sq, hn] x [b * np, hn, sk] -> [b * np, sq, sk]
        matmul_result = torch.matmul(
            query_layer.transpose(0, 1),
            key_layer.permute(1, 2, 0)) / self.norm_factor
        attention_scores = matmul_result.view(
            batch_size, num_heads, query_length, key_length)

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================
        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    # Only the mask row of the newest position is needed.
                    attention_mask = attention_mask[
                        ..., key_length - 1, :key_length].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ..., :key_length, :key_length]

        if context_length is not None:
            # Clone first so the caller's mask is not modified in place.
            attention_mask = torch.clone(attention_mask)
            attention_mask[:, :, context_length:, :] = True

        attention_scores = attention_scores - attention_mask * 10000.0
        if self.attention_softmax_in_fp32:
            # Cast back to the value dtype (not unconditionally to fp16) so
            # the bmm below also works for float32 / bfloat16 modules.
            attention_probs = self.softmax(
                attention_scores.float()).to(value_layer.dtype)
        else:
            attention_probs = self.softmax(attention_scores)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================
        head_size = value_layer.size(3)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0), flat_heads, -1)
        # [b, np, sq, sk] -> [b * np, sq, sk]
        attention_probs = attention_probs.view(flat_heads, query_length, -1)

        # [b * np, sq, sk] x [b * np, sk, hn] -> [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs,
                                  value_layer.transpose(0, 1))
        context_layer = context_layer.view(
            batch_size, num_heads, query_length, head_size)
        # [b, np, sq, hn] -> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
        # [sq, b, np, hn] -> [sq, b, h]
        merged_shape = context_layer.size()[:-2] + (self.hidden_size, )
        context_layer = context_layer.view(*merged_shape)

        # =================
        # Output. [sq, b, h]
        # =================
        output = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output
class TransformerLayer(torch.nn.Module):
    """A single transformer block with layer norm applied before each sub-layer.

    The block computes ``x + attention(norm(x))`` followed by
    ``y + mlp(norm(y))`` and returns a tensor shaped like its input.
    """

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        layer_number,
        layernorm_epsilon=1e-5,
        fp16=True,
        attention_softmax_in_fp32=True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.layernorm_epsilon = layernorm_epsilon
        self.layer_number = layer_number

        # Normalisation applied to the block input, ahead of attention.
        self.input_layernorm = torch.nn.LayerNorm(
            hidden_size, eps=self.layernorm_epsilon)

        self.attention = SelfAttention(hidden_size, num_attention_heads,
                                       layer_number, fp16,
                                       attention_softmax_in_fp32)

        # Normalisation applied to the attention residual, ahead of the MLP.
        self.post_attention_layernorm = torch.nn.LayerNorm(
            self.hidden_size, eps=self.layernorm_epsilon)

        self.mlp = MLP(self.hidden_size)

    def forward(
        self,
        hidden_states,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        """Run the block.

        All arguments other than ``hidden_states`` are forwarded unchanged to
        the attention sub-module. Returns the block output, or
        ``[output, presents]`` when ``get_key_value`` is true.
        """
        normed_input = self.input_layernorm(hidden_states)

        attention_result = self.attention(
            normed_input,
            attention_mask,
            layer_past=layer_past,
            get_key_value=get_key_value,
            prompt_length=prompt_length,
            context_length=context_length)
        if get_key_value:
            attention_result, kv_cache = attention_result

        # First residual connection, around attention.
        after_attention = attention_result + hidden_states

        # Second residual connection, around the MLP.
        mlp_result = self.mlp(self.post_attention_layernorm(after_attention))
        block_output = mlp_result + after_attention

        return [block_output, kv_cache] if get_key_value else block_output
class TopQueryLayer(torch.nn.Module):
    """A single top-query block.

    Same structure as a transformer block (layer norm, attention, residual,
    layer norm, MLP, residual), except that the attention queries are taken
    from ``query_hidden_state`` instead of the normalised hidden states.
    """

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        layer_number,
        layernorm_epsilon=1e-5,
    ):
        super(TopQueryLayer, self).__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.layernorm_epsilon = layernorm_epsilon
        self.layer_number = layer_number

        # Normalisation applied to the block input, ahead of attention.
        self.input_layernorm = torch.nn.LayerNorm(
            self.hidden_size, eps=self.layernorm_epsilon)

        # Attention with externally supplied queries.
        self.attention = TopQuerySelfAttention(self.hidden_size,
                                               self.num_attention_heads,
                                               self.layer_number)

        # Normalisation applied to the attention residual, ahead of the MLP.
        self.post_attention_layernorm = torch.nn.LayerNorm(
            self.hidden_size, eps=self.layernorm_epsilon)

        # MLP
        self.mlp = MLP(self.hidden_size)

    def forward(
        self,
        hidden_states,
        query_hidden_state,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        """Run the block.

        ``query_hidden_state`` is required. The remaining arguments are
        forwarded unchanged to the attention sub-module. Returns the block
        output, or ``[output, presents]`` when ``get_key_value`` is true.
        """
        # Identity check: ``!= None`` would dispatch to the operand's rich
        # comparison, which need not return a plain bool.
        assert query_hidden_state is not None

        layernorm_output = self.input_layernorm(hidden_states)

        # Attention: keys/values from the normalised hidden states, queries
        # from ``query_hidden_state``.
        attention_output = self.attention(
            layernorm_output,
            query_hidden_state,
            attention_mask,
            layer_past=layer_past,
            get_key_value=get_key_value,
            prompt_length=prompt_length,
            context_length=context_length)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        residual = hidden_states
        layernorm_input = attention_output + residual

        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output = self.mlp(layernorm_output)

        # Second residual connection.
        residual = layernorm_input
        output = mlp_output + residual

        if get_key_value:
            output = [output, presents]

        return output
class Transformer(torch.nn.Module):
    """Stack of ``TransformerLayer`` blocks followed by a top-query block.

    ``forward`` takes batch-first ``[b, s, h]`` tensors, runs the layers in
    sequence-first layout, applies a final layer norm, then the top-query
    layer, and returns a batch-first tensor again.

    Arguments:
        hidden_size: model width.
        num_attention_heads: number of attention heads in every layer.
        num_layers: number of ``TransformerLayer`` blocks.
        layernorm_epsilon: epsilon used by the final layer norm.
    """

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        num_layers,
        layernorm_epsilon=1e-5,
    ):
        super(Transformer, self).__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.layernorm_epsilon = layernorm_epsilon

        # Number of layers. Layer sharing is not configurable here, so every
        # layer is unique; the attribute is kept for the index helper below.
        self.num_layers = num_layers
        self.num_unique_layers = self.num_layers

        # Transformer layers, numbered from 1.
        self.layers = torch.nn.ModuleList([
            TransformerLayer(self.hidden_size, self.num_attention_heads,
                             layer_number)
            for layer_number in range(1, self.num_unique_layers + 1)
        ])

        self.topQueryLayer = TopQueryLayer(self.hidden_size,
                                           self.num_attention_heads,
                                           self.num_unique_layers)

        self.final_layernorm = torch.nn.LayerNorm(
            self.hidden_size, eps=self.layernorm_epsilon)

    def _get_layer_index(self, layer_number):
        """Map a layer position onto an index into ``self.layers``."""
        return layer_number % self.num_unique_layers

    def _get_layer(self, layer_number):
        """Return the layer module used at position ``layer_number``."""
        return self.layers[self._get_layer_index(layer_number)]

    def forward(
        self,
        hidden_states,
        query_hidden_state,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        """Run the full stack.

        Arguments:
            hidden_states: token representations, ``[b, s, h]``.
            query_hidden_state: top-query representations, ``[b, s, h]``.
            attention_mask: forwarded to every layer.
            layer_past: optional per-layer caches, indexed ``0..num_layers``;
                entry ``num_layers`` belongs to the top-query layer.
            get_key_value: when true, also return the per-layer caches.
            prompt_length: forwarded to every layer.
            context_length: forwarded to every layer.

        Returns:
            The ``[b, s, h]`` output, or ``[output, presents]`` when
            ``get_key_value`` is true, where ``presents`` has one entry per
            layer followed by the top-query layer's entry.
        """
        # data format change to avoid explicit transposes : [b s h] --> [s b h]
        hidden_states = hidden_states.transpose(0, 1).contiguous()
        query_hidden_state = query_hidden_state.transpose(0, 1).contiguous()

        presents = []

        for index in range(self.num_layers):
            layer = self._get_layer(index)
            past = None if layer_past is None else layer_past[index]
            hidden_states = layer(
                hidden_states,
                attention_mask,
                layer_past=past,
                get_key_value=get_key_value,
                prompt_length=prompt_length,
                context_length=context_length)
            if get_key_value:
                hidden_states, present = hidden_states
                presents.append(present)

        hidden_states_ = self.final_layernorm(hidden_states)

        #################################
        # top query layer
        #################################
        past = None if layer_past is None else layer_past[self.num_layers]
        hidden_states = self.topQueryLayer(
            hidden_states_,
            query_hidden_state,
            attention_mask,
            layer_past=past,
            get_key_value=get_key_value,
            prompt_length=prompt_length,
            context_length=context_length)

        if get_key_value:
            hidden_states, present = hidden_states
            presents.append(present)

        # reverting data format change [s b h] --> [b s h]
        output = hidden_states.transpose(0, 1).contiguous()

        if get_key_value:
            output = [output, presents]

        return output

    def state_dict_for_save_checkpoint(self,
                                       destination=None,
                                       prefix='',
                                       keep_vars=False):
        """Return this module's regular ``state_dict``."""
        # Keyword arguments: positional use of Module.state_dict is deprecated.
        return self.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars)
class Embedding(torch.nn.Module):
    """Language model embeddings.

    Sums a word embedding and a learned position embedding. The position
    embedding table is converted to half precision at construction time.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
    """

    def __init__(
        self,
        hidden_size,
        vocab_size,
        max_sequence_length,
    ):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length

        # Word embeddings.
        self.word_embeddings = torch.nn.Embedding(self.vocab_size,
                                                  self.hidden_size)
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding.
        self.position_embeddings = torch.nn.Embedding(self.max_sequence_length,
                                                      self.hidden_size)
        self.position_embeddings = self.position_embeddings.half()
        self._position_embeddings_key = 'position_embeddings'

    def forward(self, input_ids, position_ids):
        """Return word embeddings plus position embeddings for the given ids."""
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        return words_embeddings + position_embeddings

    def state_dict_for_save_checkpoint(self,
                                       destination=None,
                                       prefix='',
                                       keep_vars=False):
        """For easy load.

        Returns a dict with one nested state dict per embedding table. If a
        ``destination`` is supplied, it is handed to both sub-modules with the
        same ``prefix``.
        """
        # Keyword arguments: positional use of Module.state_dict is deprecated.
        return {
            self._word_embeddings_key: self.word_embeddings.state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars),
            self._position_embeddings_key: self.position_embeddings.state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars),
        }

    @staticmethod
    def _select_state(state_dict, key):
        """Return the sub-state stored under ``key``.

        Uses ``state_dict[key]`` when present; otherwise (backward
        compatibility) collects every flat entry whose name contains ``key``
        and strips everything up to and including ``key + '.'``.
        """
        if key in state_dict:
            return state_dict[key]
        marker = key + '.'
        return {
            name.split(marker)[1]: value
            for name, value in state_dict.items() if key in name
        }

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts either the nested layout produced by
        ``state_dict_for_save_checkpoint`` or a flat layout with dotted names.
        The word embedding weight is cut down to the first ``vocab_size``
        rows. The caller's ``state_dict`` is left unmodified.
        """
        # Word embedding. Work on a shallow copy so the truncated weight does
        # not overwrite the entry inside the caller's checkpoint.
        word_state = dict(
            self._select_state(state_dict, self._word_embeddings_key))
        word_state['weight'] = word_state['weight'][:self.vocab_size]
        self.word_embeddings.load_state_dict(word_state, strict=strict)

        # Position embedding.
        position_state = self._select_state(state_dict,
                                            self._position_embeddings_key)
        self.position_embeddings.load_state_dict(position_state, strict=strict)
class QueryEmbedding(torch.nn.Module):
    """Top-query position embeddings.

    Holds a single embedding table indexed by position id; the table is
    converted to half precision at construction time.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size (stored, not used by the table)
        max_sequence_length: maximum size of sequence. This
                             is the number of rows in the table
    """

    def __init__(
        self,
        hidden_size,
        vocab_size,
        max_sequence_length,
    ):
        super(QueryEmbedding, self).__init__()

        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length

        # Top query position embedding (serial).
        self.top_query_embeddings = torch.nn.Embedding(
            self.max_sequence_length, self.hidden_size)
        self.top_query_embeddings = self.top_query_embeddings.half()
        self._top_query_embeddings_key = 'top_query_embeddings'

    def forward(self, position_ids):
        """Return the top-query embedding for each position id."""
        return self.top_query_embeddings(position_ids)

    def state_dict_for_save_checkpoint(self,
                                       destination=None,
                                       prefix='',
                                       keep_vars=False):
        """For easy load.

        Returns a dict holding the embedding table's state dict under the
        ``'top_query_embeddings'`` key.
        """
        # Keyword arguments: positional use of Module.state_dict is deprecated
        # and triggers a warning on recent PyTorch releases.
        return {
            self._top_query_embeddings_key:
            self.top_query_embeddings.state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars),
        }

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts either the nested layout produced by
        ``state_dict_for_save_checkpoint`` or a flat layout with dotted names.
        """
        if self._top_query_embeddings_key in state_dict:
            state_dict_ = state_dict[self._top_query_embeddings_key]
        else:
            # for backward compatibility: pick the flat entries and strip
            # everything up to and including 'top_query_embeddings.'.
            state_dict_ = {
                name.split('top_query_embeddings.')[1]: value
                for name, value in state_dict.items()
                if 'top_query_embeddings' in name
            }
        self.top_query_embeddings.load_state_dict(state_dict_, strict=strict)
class TransformerLanguageModel(torch.nn.Module):
    """Transformer language model: embeddings plus the transformer stack.

    Arguments:
        hidden_size: model width
        num_layers: number of transformer layers
        num_attention_heads: number of attention heads per layer
        padded_vocab_size: number of rows in the word embedding table
        max_position_embeddings: maximum size of sequence. This
                                 is used for positional embedding
    """

    def __init__(
        self,
        hidden_size,
        num_layers,
        num_attention_heads,
        padded_vocab_size,
        max_position_embeddings,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.padded_vocab_size = padded_vocab_size
        self.max_position_embeddings = max_position_embeddings

        # Checkpoint section names for the three sub-modules.
        self._embedding_key = 'embedding'
        self._topQueryEmbedding_key = 'topQueryEmbedding'
        self._transformer_key = 'transformer'

        # Token + position embeddings.
        self.embedding = Embedding(self.hidden_size, self.padded_vocab_size,
                                   self.max_position_embeddings)

        # Position-indexed embeddings that feed the top-query layer.
        self.topQueryEmbedding = QueryEmbedding(self.hidden_size,
                                                self.padded_vocab_size,
                                                self.max_position_embeddings)

        # Layer stack.
        self.transformer = Transformer(self.hidden_size,
                                       self.num_attention_heads,
                                       self.num_layers)

    def forward(
        self,
        input_ids,
        position_ids,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        """Embed the ids and run the transformer.

        The same ``position_ids`` index both the position embedding and the
        top-query embedding. Returns whatever the transformer returns.
        """
        token_states = self.embedding(input_ids, position_ids)
        top_query_states = self.topQueryEmbedding(position_ids)

        return self.transformer(
            token_states,
            top_query_states,
            attention_mask,
            layer_past=layer_past,
            get_key_value=get_key_value,
            prompt_length=prompt_length,
            context_length=context_length)

    def state_dict_for_save_checkpoint(self,
                                       destination=None,
                                       prefix='',
                                       keep_vars=False):
        """For easy load: one nested entry per sub-module."""
        sections = {}
        for section_key, module in (
            (self._embedding_key, self.embedding),
            (self._topQueryEmbedding_key, self.topQueryEmbedding),
            (self._transformer_key, self.transformer),
        ):
            sections[section_key] = module.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        return sections

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Each sub-module is loaded from its own section when the section key is
        present; otherwise (backward compatibility) the matching flat entries
        are gathered from ``state_dict``.
        """
        # Embedding.
        if self._embedding_key in state_dict:
            embedding_state = state_dict[self._embedding_key]
        else:
            embedding_state = {
                name: value
                for name, value in state_dict.items() if '_embeddings' in name
            }
        self.embedding.load_state_dict(embedding_state, strict=strict)

        # Top-query embedding.
        if self._topQueryEmbedding_key in state_dict:
            top_query_state = state_dict[self._topQueryEmbedding_key]
        else:
            top_query_state = {
                name: value
                for name, value in state_dict.items() if '_embeddings' in name
            }
        self.topQueryEmbedding.load_state_dict(top_query_state, strict=strict)

        # Transformer.
        if self._transformer_key in state_dict:
            transformer_state = state_dict[self._transformer_key]
        else:
            # Strip everything up to and including 'transformer.'.
            transformer_state = {
                name.split('transformer.')[1]: value
                for name, value in state_dict.items() if 'transformer.' in name
            }
        self.transformer.load_state_dict(transformer_state, strict=strict)
class CodeGeeXModel(torch.nn.Module):
    """CodeGeeX: A Multilingual Code Generation Model.

    Wraps a ``TransformerLanguageModel`` and projects its output onto the
    vocabulary with the (tied) word embedding matrix.
    """

    def __init__(
        self,
        hidden_size,
        num_layers,
        num_attention_heads,
        padded_vocab_size,
        max_position_embeddings,
    ):
        super(CodeGeeXModel, self).__init__()

        self.language_model = TransformerLanguageModel(
            hidden_size, num_layers, num_attention_heads, padded_vocab_size,
            max_position_embeddings)
        self._language_model_key = 'language_model'

    def forward(
        self,
        input_ids,
        position_ids,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        prompt_length=None,
        context_length=None,
    ):
        """Return vocabulary logits for the given ids.

        All arguments are forwarded to the language model. Returns the logits
        (last dimension ``padded_vocab_size``), or ``[logits, presents]`` when
        ``get_key_value`` is true.
        """
        # Language model.
        lm_output = self.language_model(
            input_ids,
            position_ids,
            attention_mask,
            layer_past=layer_past,
            get_key_value=get_key_value,
            prompt_length=prompt_length,
            context_length=context_length)

        if get_key_value:
            lm_output, presents = lm_output

        # Tied output projection. Match the weight to the hidden-state dtype
        # (rather than forcing fp16) so F.linear gets consistent operands.
        output_weight = self.language_model.embedding.word_embeddings.weight
        output = F.linear(lm_output, output_weight.to(lm_output.dtype))

        if get_key_value:
            output = [output, presents]

        return output

    def state_dict_for_save_checkpoint(self,
                                       destination=None,
                                       prefix='',
                                       keep_vars=False):
        """Return the language model's checkpoint state under one key."""
        return {
            self._language_model_key:
            self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars),
        }

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Unwraps the ``'language_model'`` section when present and delegates to
        the language model's own loader.
        """
        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)
|