| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009 |
- # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
- # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
- # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import (absolute_import, division, print_function,
- unicode_literals)
- import logging
- import math
- import torch
- import torch.nn.functional as F
- from megatron_util import mpu
- from torch import nn
- from modelscope.utils.nlp.distributed import (normal_init_method,
- scaled_init_method)
- from .configuration import PlugNLGConfig, PlugNLUConfig
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def gelu(x):
    """Exact (erf-based) Gaussian Error Linear Unit, applied element-wise.

    Computes ``x * Phi(x)``, with the standard normal CDF written as
    ``Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))``.

    For information: OpenAI GPT's gelu is slightly different (and gives
    slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

    Args:
        x: input tensor of any shape.

    Returns:
        Tensor with the same shape as ``x``.
    """
    half_x = x * 0.5
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_x * (1.0 + erf_term)
def swish(x):
    """Swish (SiLU) activation ``x * sigmoid(x)``, applied element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
# Maps an activation name (as used in ``config.hidden_act``) to its callable.
ACT2FN = dict(
    gelu=gelu,
    relu=torch.nn.functional.relu,
    swish=swish,
)
class BertLayerNorm(nn.Module):
    """LayerNorm over the last dimension, TF style (epsilon inside the sqrt).

    Each row is shifted to zero mean and scaled to unit (biased) variance
    along the last axis. The learned element-wise ``weight`` and ``bias``
    are then applied.
    """

    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).

        Args:
            hidden_size: size of the normalized (last) dimension.
            eps: constant added to the variance before taking the square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.

    The three lookups are summed, normalized by ``LayerNorm`` and passed
    through dropout. The ``fp32_layernorm`` / ``fp32_embedding`` /
    ``fp32_tokentypes`` config flags select the dtypes used for the sum and
    for the LayerNorm (see ``forward``).
    """

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        # Word table built with mpu's VocabParallelEmbedding (presumably split
        # across model-parallel ranks along the vocab, per its name — confirm).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            init_method=normal_init_method(
                mean=0.0, std=config.initializer_range))
        # Plain (non-parallel) torch embedding tables.
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.fp32_layernorm = config.fp32_layernorm
        self.fp32_embedding = config.fp32_embedding
        self.fp32_tokentypes = config.fp32_tokentypes
        self.LayerNorm = BertLayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        """Embed ``input_ids``.

        Args:
            input_ids: token ids; ``input_ids.size(1)`` is used as the
                sequence length.
            token_type_ids: optional segment ids; defaults to zeros shaped
                like ``input_ids``.
            position_ids: optional position ids; defaults to
                ``0 .. seq_length - 1`` expanded to the shape of ``input_ids``.

        Returns:
            The summed embeddings after LayerNorm and dropout.
        """
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device)
            # Same positions for every row of the batch.
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        if not self.fp32_tokentypes:
            # Sum in the embedding tables' own dtype.
            embeddings = words_embeddings + position_embeddings + token_type_embeddings
            if self.fp32_embedding and not self.fp32_layernorm:
                embeddings = embeddings.half()
            previous_type = embeddings.type()
            if self.fp32_layernorm:
                # LayerNorm is computed in float32 ...
                embeddings = embeddings.float()
            embeddings = self.LayerNorm(embeddings)
            if self.fp32_layernorm:
                # ... and the result is cast to half (fp32_embedding) or back
                # to the dtype the sum had before the LayerNorm.
                if self.fp32_embedding:
                    embeddings = embeddings.half()
                else:
                    embeddings = embeddings.type(previous_type)
        else:
            # Sum in float32.
            embeddings = words_embeddings.float() + position_embeddings.float(
            ) + token_type_embeddings.float()
            # NOTE(review): this branch only runs when fp32_tokentypes is true,
            # so the test below reduces to ``not self.fp32_layernorm`` and the
            # inner ``if self.fp32_tokentypes`` further down always casts to
            # half.
            if self.fp32_tokentypes and not self.fp32_layernorm:
                embeddings = embeddings.half()
            previous_type = embeddings.type()
            if self.fp32_layernorm:
                embeddings = embeddings.float()
            embeddings = self.LayerNorm(embeddings)
            if self.fp32_layernorm:
                if self.fp32_tokentypes:
                    embeddings = embeddings.half()
                else:
                    embeddings = embeddings.type(previous_type)
        embeddings = self.dropout(embeddings)
        return embeddings
class BertSelfOutput(nn.Module):
    """Output projection of the self-attention sub-layer.

    Applies a row-parallel linear projection and dropout, then adds the
    residual ``input_tensor``. In post-LN mode (``config.pre_ln`` false) the
    sum is normalized by ``LayerNorm``, optionally in float32. In pre-LN mode
    the module has no LayerNorm and returns the residual sum unchanged.
    """

    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        use_deep_init = getattr(config, 'deep_init', False)
        if use_deep_init:
            init_method = scaled_init_method(
                mean=0.0,
                std=config.initializer_range,
                num_layers=config.num_hidden_layers)
        else:
            init_method = normal_init_method(
                mean=0.0, std=config.initializer_range)
        self.dense = mpu.RowParallelLinear(
            input_size=config.hidden_size,
            output_size=config.hidden_size,
            bias=True,
            input_is_parallel=True,
            stride=1,
            init_method=init_method)
        self.fp32_layernorm = config.fp32_layernorm
        self.LayerNorm = None if config.pre_ln else BertLayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        residual_sum = self.dropout(self.dense(hidden_states)) + input_tensor
        if self.LayerNorm is None:
            return residual_sum
        original_type = residual_sum.type()
        norm_input = residual_sum.float() if self.fp32_layernorm else residual_sum
        normalized = self.LayerNorm(norm_input)
        if self.fp32_layernorm:
            normalized = normalized.type(original_type)
        return normalized
class BertAttention(nn.Module):
    """Self-attention sub-layer: optional pre-LN, parallel self-attention, then ``BertSelfOutput``.

    In pre-LN mode (``config.pre_ln``) the input is normalized (optionally in
    float32) before attention. The residual given to ``BertSelfOutput`` is
    always the original, un-normalized input.
    """

    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.fp32_layernorm = config.fp32_layernorm
        self.LayerNorm = BertLayerNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon) if config.pre_ln else None
        self.self = mpu.BertParallelSelfAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            dropout_prob=config.attention_probs_dropout_prob,
            output_parallel=True,
            init_method=normal_init_method(
                mean=0.0, std=config.initializer_range),
            separate=config.attn_separate)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        attn_input = input_tensor
        if self.LayerNorm is not None:
            saved_type = input_tensor.type()
            ln_in = input_tensor.float() if self.fp32_layernorm else input_tensor
            attn_input = self.LayerNorm(ln_in)
            if self.fp32_layernorm:
                attn_input = attn_input.type(saved_type)
        self_output = self.self(attn_input, attention_mask)
        return self.output(self_output, input_tensor)
class BertIntermediate(nn.Module):
    """Feed-forward expansion: column-parallel linear projection followed by the activation.

    ``config.hidden_act`` is either an ``ACT2FN`` key or a callable used as-is.
    """

    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = mpu.ColumnParallelLinear(
            input_size=config.hidden_size,
            output_size=config.intermediate_size,
            bias=True,
            gather_output=False,
            stride=1,
            init_method=normal_init_method(
                mean=0.0, std=config.initializer_range))
        activation = config.hidden_act
        if isinstance(activation, str):
            activation = ACT2FN[activation]
        self.intermediate_act_fn = activation

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
class BertOutput(nn.Module):
    """Output projection of the feed-forward sub-layer.

    Projects from ``intermediate_size`` back to ``hidden_size`` with a
    row-parallel linear, applies dropout and adds the residual
    ``input_tensor``. In post-LN mode (``config.pre_ln`` false) the sum is
    normalized by ``LayerNorm``, optionally in float32.
    """

    def __init__(self, config):
        super(BertOutput, self).__init__()
        deep_init = hasattr(config, 'deep_init') and config.deep_init
        init_method = (
            scaled_init_method(
                mean=0.0,
                std=config.initializer_range,
                num_layers=config.num_hidden_layers) if deep_init else
            normal_init_method(mean=0.0, std=config.initializer_range))
        self.dense = mpu.RowParallelLinear(
            input_size=config.intermediate_size,
            output_size=config.hidden_size,
            bias=True,
            input_is_parallel=True,
            stride=1,
            init_method=init_method)
        self.fp32_layernorm = config.fp32_layernorm
        if config.pre_ln:
            self.LayerNorm = None
        else:
            self.LayerNorm = BertLayerNorm(
                config.hidden_size, eps=config.layernorm_epsilon)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        ln_input = projected + input_tensor
        if self.LayerNorm is None:
            return ln_input
        previous_type = ln_input.type()
        if self.fp32_layernorm:
            return self.LayerNorm(ln_input.float()).type(previous_type)
        return self.LayerNorm(ln_input)
class BertLayer(nn.Module):
    """One encoder block: attention sub-layer followed by the feed-forward sub-layer.

    In pre-LN mode (``config.pre_ln``) the attention output is normalized
    (optionally in float32) before the feed-forward network. The residual
    passed to ``BertOutput`` is always the un-normalized attention output.
    """

    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        self.fp32_layernorm = config.fp32_layernorm
        self.LayerNorm = BertLayerNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon) if config.pre_ln else None

    def forward(self, hidden_states, attention_mask):
        attention_output = self.attention(hidden_states, attention_mask)
        ffn_input = attention_output
        if self.LayerNorm is not None:
            original_type = attention_output.type()
            if self.fp32_layernorm:
                ffn_input = self.LayerNorm(
                    attention_output.float()).type(original_type)
            else:
                ffn_input = self.LayerNorm(attention_output)
        return self.output(self.intermediate(ffn_input), attention_output)
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``BertLayer`` modules.

    In pre-LN mode (``config.pre_ln``) the stack ends with a final
    ``LayerNorm``, optionally computed in float32 when
    ``config.fp32_layernorm`` is set. Otherwise the last layer's output is
    returned without further normalization.
    """

    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.layer = nn.ModuleList(
            [BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.fp32_layernorm = config.fp32_layernorm
        if config.pre_ln:
            self.LayerNorm = BertLayerNorm(
                config.hidden_size, eps=config.layernorm_epsilon)
        else:
            self.LayerNorm = None

    def _final_norm(self, hidden_states):
        """Apply the closing LayerNorm if present, optionally in float32."""
        if self.LayerNorm is None:
            return hidden_states
        previous_type = hidden_states.type()
        if self.fp32_layernorm:
            hidden_states = hidden_states.float()
        hidden_states = self.LayerNorm(hidden_states)
        if self.fp32_layernorm:
            hidden_states = hidden_states.type(previous_type)
        return hidden_states

    def forward(
        self,
        hidden_states,
        attention_mask,
        output_all_encoded_layers=True,
        checkpoint_activations=False,
        detach_index=-1,
    ):
        """Run the layer stack.

        Args:
            hidden_states: input to the first layer (the embedding output).
            attention_mask: mask passed unchanged to every layer.
            output_all_encoded_layers: collect every layer's output (True) or
                only the last one (False). Ignored when checkpointing.
            checkpoint_activations: run each layer through ``mpu.checkpoint``.
            detach_index: index of the layer whose output is detached in
                place (-1 disables this).

        Returns:
            A list of hidden states. It holds one entry per layer when
            ``output_all_encoded_layers`` is true and checkpointing is off.
            Otherwise it is a single-element list with the last output. In
            every mode the final LayerNorm (pre-LN only) is applied exactly
            once, to the last output.
        """
        all_encoder_layers = []

        def custom(start, end):

            def custom_forward(*inputs):
                layers = self.layer[start:end]
                x_ = inputs[0]
                for layer in layers:
                    x_ = layer(x_, inputs[1])
                return x_

            return custom_forward

        if checkpoint_activations:
            layer_idx = 0
            num_layers = len(self.layer)
            chunk_length = 1
            while layer_idx < num_layers:
                hidden_states = mpu.checkpoint(
                    custom(layer_idx, layer_idx + chunk_length), hidden_states,
                    attention_mask * 1)
                if detach_index == layer_idx:
                    hidden_states.detach_()
                layer_idx += chunk_length
            # Checkpointed runs only return the (normalized) last output.
            all_encoder_layers.append(self._final_norm(hidden_states))
            return all_encoder_layers

        last_index = len(self.layer) - 1
        for i, layer_module in enumerate(self.layer):
            hidden_states = layer_module(hidden_states, attention_mask)
            if detach_index == i:
                hidden_states.detach_()
            if i == last_index:
                # The only place the final norm is applied on this path;
                # previously it was applied again after the loop when
                # output_all_encoded_layers was False (double LayerNorm).
                hidden_states = self._final_norm(hidden_states)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            if not self.layer:
                # No layer ran, so the loop above never normalized.
                hidden_states = self._final_norm(hidden_states)
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers
class BertPooler(nn.Module):
    """Pools a sequence by passing its first token's hidden state through ``dense`` and tanh."""

    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # "Pool" by keeping only the hidden state at sequence position 0.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
class BertPredictionHeadTransform(nn.Module):
    """Dense projection, activation, then LayerNorm (optionally in float32) for the LM head."""

    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        act = config.hidden_act
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.LayerNorm = BertLayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)
        self.fp32_layernorm = config.fp32_layernorm

    def forward(self, hidden_states):
        activated = self.transform_act_fn(self.dense(hidden_states))
        saved_type = activated.type()
        if not self.fp32_layernorm:
            return self.LayerNorm(activated)
        return self.LayerNorm(activated.float()).type(saved_type)
class BertLMPredictionHead(nn.Module):
    """LM head: transform the hidden states, then project them with the
    (shared) input embedding matrix plus a learned per-row bias.
    """

    def __init__(self, config, bert_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder_weight = bert_model_embedding_weights
        # One bias entry per row of the embedding matrix.
        self.bias = nn.Parameter(
            torch.zeros(bert_model_embedding_weights.size(0)))
        # Marker attribute on the parameter; presumably read by the
        # model-parallel utilities (not read in this file section) — confirm.
        self.bias.model_parallel = True
        self.fp32_embedding = config.fp32_embedding
        self.fp32_layernorm = config.fp32_layernorm

        def convert_to_type(tensor):
            # Cast to half when fp32_embedding is set; otherwise pass through.
            if self.fp32_embedding:
                return tensor.half()
            else:
                return tensor

        self.type_converter = convert_to_type
        # Whether the one-time dtype conversion in forward has already run.
        self.converted = False

    def forward(self, hidden_states):
        """Return logits for ``hidden_states`` against ``decoder_weight``.

        On the first call only, when ``fp32_embedding`` is set, the transform
        is converted to half in place. Its LayerNorm is put back to float32
        when ``fp32_layernorm`` is also set. The inputs, weight and bias of
        the final ``F.linear`` all go through ``type_converter``.
        NOTE(review): ``decoder_weight`` comes from the caller (a
        VocabParallelEmbedding weight in PalmForPreTraining), so the logits
        presumably cover only this rank's vocab partition — confirm.
        """
        if not self.converted:
            self.converted = True
            if self.fp32_embedding:
                self.transform.half()
                if self.fp32_layernorm:
                    self.transform.LayerNorm.float()
        hidden_states = self.transform(self.type_converter(hidden_states))
        hidden_states = mpu.copy_to_model_parallel_region(hidden_states)
        hidden_states = F.linear(
            self.type_converter(hidden_states),
            self.type_converter(self.decoder_weight),
            self.type_converter(self.bias))
        return hidden_states
class BertPreTrainingHeads(nn.Module):
    """LM prediction head plus a 3-way classifier over ``pooled_output``.

    Returns ``(prediction_scores, seq_relationship_score)``.
    """

    def __init__(self, config, bert_model_embedding_weights):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config,
                                                bert_model_embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 3)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        # Cast pooled_output to the classifier parameters' dtype.
        for param in self.seq_relationship.parameters():
            if param is not None:
                pooled_output = pooled_output.type_as(param)
        return prediction_scores, self.seq_relationship(pooled_output)
class PreTrainedBertModel(nn.Module):
    """Base class for the PLUG BERT-style modules.

    It validates that ``config`` is a ``PlugNLUConfig`` or ``PlugNLGConfig``
    and provides ``init_bert_weights`` for use with ``nn.Module.apply``.
    Extra positional/keyword arguments to ``__init__`` are accepted and
    ignored.
    """

    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedBertModel, self).__init__()
        if not isinstance(config, (PlugNLUConfig, PlugNLGConfig)):
            # Name the classes that are actually accepted (the old message
            # pointed at a nonexistent `BertConfig` / `from_pretrained`).
            raise ValueError(
                'Parameter config in `{}(config)` should be an instance of '
                'class `PlugNLUConfig` or `PlugNLGConfig`, got `{}`.'.format(
                    self.__class__.__name__,
                    type(config).__name__))
        self.config = config

    def init_bert_weights(self, module):
        """Initialize ``module``'s weights in place.

        * ``nn.Linear`` / ``nn.Embedding`` weights are drawn from
          N(0, ``config.initializer_range``).
        * ``BertLayerNorm`` gets weight 1 and bias 0.
        * ``nn.Linear`` biases (when present) are zeroed.

        Any other module type is left unchanged.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
class BertModel(PreTrainedBertModel):
    """BERT encoder ("Bidirectional Encoder Representations from Transformers").

    Params:
        config: a ``PlugNLUConfig`` / ``PlugNLGConfig`` instance describing the model.

    Inputs:
        `input_ids`: LongTensor [batch_size, sequence_length] of token indices.
        `token_type_ids`: optional LongTensor of the same shape with segment ids
            (0 = sentence A, 1 = sentence B). Defaults to all zeros.
        `attention_mask`: optional tensor of the same shape, 1 for tokens to
            attend to and 0 for padding. Defaults to all ones.
        `output_all_encoded_layers`: return every layer's output (True) or only
            the last one (False). Default: True.
        `checkpoint_activations`: run the encoder with activation checkpointing;
            then only the last layer's output is returned.
        `detach_index`: index of an encoder layer whose output is detached in
            place (-1 disables this).

    Outputs: tuple ``(encoded_layers, pooled_output)``
        `encoded_layers`: a list with one [batch_size, sequence_length,
            hidden_size] tensor per layer when ``output_all_encoded_layers`` is
            True and checkpointing is off. Otherwise it is the single last-layer
            tensor (not wrapped in a list).
        `pooled_output`: the last-layer hidden state at position 0, cast to the
            dtype of the pooler's parameters. NOTE: the ``pooler`` sub-module
            (dense + tanh) is not applied.

    Example:
        >>> input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
        >>> input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
        >>> token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
        >>> model = BertModel(config)
        >>> all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    """

    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(
        self,
        input_ids,
        token_type_ids=None,
        attention_mask=None,
        output_all_encoded_layers=True,
        checkpoint_activations=False,
        detach_index=-1,
    ):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Reshape to [batch, 1, 1, seq] so the mask broadcasts over heads and
        # query positions. After the affine map below, attended positions
        # hold 0.0 and masked ones -10000.0; the result is added to the raw
        # attention scores. It is cast to the encoder dtype for fp16 support.
        encoder_dtype = next(self.encoder.parameters()).dtype
        broadcast_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(
            dtype=encoder_dtype)
        extended_attention_mask = (1.0 - broadcast_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(
            embedding_output,
            extended_attention_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            checkpoint_activations=checkpoint_activations,
            detach_index=detach_index)

        sequence_output = encoded_layers[-1]
        pooler_param = next(self.pooler.parameters(), None)
        if pooler_param is not None:
            sequence_output = sequence_output.type_as(pooler_param)
        pooled_output = sequence_output[:, 0]

        if checkpoint_activations or not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output
class DecodeLayer(nn.Module):
    """One pre-LN decoder block with three residual sub-blocks.

    The sub-blocks are: self-attention masked by ``dec_attn_mask``,
    cross-attention over ``enc_hidden_states``, and a feed-forward MLP. Each
    input is normalized first (in float32 when ``config.fp32_layernorm`` is
    set, then cast back to the layer input's type), and the sub-block output
    is added to the un-normalized input.
    """

    def __init__(self, config):
        super(DecodeLayer, self).__init__()
        init_method = normal_init_method(
            mean=0.0, std=config.initializer_range)
        output_layer_init_method = scaled_init_method(
            mean=0.0,
            std=config.initializer_range,
            num_layers=config.num_hidden_layers)

        self.attention = mpu.GPT2ParallelSelfAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            attention_dropout_prob=config.attention_probs_dropout_prob,
            output_dropout_prob=config.hidden_dropout_prob,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
        )

        self.cross_attention = mpu.PalmParallelCrossAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            attention_dropout_prob=config.attention_probs_dropout_prob,
            output_dropout_prob=config.hidden_dropout_prob,
            init_method=init_method,
            attn_separate=False,
            output_layer_init_method=output_layer_init_method,
        )

        self.input_layernorm = BertLayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)
        self.post_attention_layernorm = BertLayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)
        self.post_cross_attention_layernorm = BertLayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)

        self.intermediate = mpu.ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            gather_output=False,
            init_method=init_method,
        )
        act = config.hidden_act
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.output = mpu.RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
        )

        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.fp32_layernorm = config.fp32_layernorm

        def convert_to_type(tensor):
            # Upcast to float32 for the LayerNorms when fp32_layernorm is set.
            return tensor.float() if self.fp32_layernorm else tensor

        self.type_converter = convert_to_type

    def _layer_norm(self, norm, hidden_states, previous_type):
        """Apply ``norm`` (through ``type_converter``) and restore ``previous_type`` in fp32 mode."""
        normalized = norm(self.type_converter(hidden_states))
        if self.fp32_layernorm:
            normalized = normalized.type(previous_type)
        return normalized

    def forward(self,
                hidden_states,
                enc_hidden_states,
                enc_attn_mask,
                dec_attn_mask,
                is_infer=False):
        # Type of the layer input; every LayerNorm output is cast back to it.
        previous_type = hidden_states.type()

        # Self-attention sub-block.
        attn_in = self._layer_norm(self.input_layernorm, hidden_states,
                                   previous_type)
        attn_out = self.attention(attn_in, dec_attn_mask, is_infer=is_infer)
        hidden_states = hidden_states + attn_out

        # Cross-attention sub-block over the encoder states.
        cross_in = self._layer_norm(self.post_attention_layernorm,
                                    hidden_states, previous_type)
        cross_out = self.cross_attention(cross_in, enc_hidden_states,
                                         enc_attn_mask)
        hidden_states = hidden_states + cross_out

        # Feed-forward sub-block.
        ffn_in = self._layer_norm(self.post_cross_attention_layernorm,
                                  hidden_states, previous_type)
        ffn_out = self.intermediate_act_fn(self.intermediate(ffn_in))
        ffn_out = self.dropout(self.output(ffn_out))
        return hidden_states + ffn_out
class BertDecoder(nn.Module):
    """Stack of ``config.dec_hidden_layers`` ``DecodeLayer`` modules followed by a final LayerNorm."""

    def __init__(self, config):
        super(BertDecoder, self).__init__()
        self.layer = nn.ModuleList(
            [DecodeLayer(config) for _ in range(config.dec_hidden_layers)])

        self.final_layernorm = BertLayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)
        self.fp32_layernorm = config.fp32_layernorm

    def forward(self,
                hidden_states,
                enc_hidden_states,
                enc_attn_mask,
                dec_attn_mask,
                checkpoint_activations=False,
                output_all_encoded_layers=False,
                is_infer=False):
        """Run every decoder layer, then the final LayerNorm.

        Args:
            hidden_states: decoder input (embedded target tokens).
            enc_hidden_states: encoder output used by cross-attention.
            enc_attn_mask: mask over encoder positions for cross-attention.
            dec_attn_mask: mask for decoder self-attention.
            checkpoint_activations: run each layer through ``mpu.checkpoint``.
            output_all_encoded_layers: unused in this method.
            is_infer: forwarded to every ``DecodeLayer``.

        Returns:
            A single-element list with the normalized last hidden states.
        """

        def custom(start, end):

            def custom_forward(*inputs):
                layers = self.layer[start:end]
                x_ = inputs[0]
                for layer in layers:
                    # dec_attn_mask comes from the enclosing scope rather
                    # than from the checkpointed ``inputs``.
                    x_ = layer(
                        x_,
                        inputs[1],
                        inputs[2],
                        dec_attn_mask * 1,
                        is_infer=is_infer)
                return x_

            return custom_forward

        # Keep a handle on the encoder tensor's data so it can be restored
        # after each checkpointed chunk (see below).
        pre_enc_hidden = enc_hidden_states.data
        if checkpoint_activations:
            layer_idx = 0
            num_layers = len(self.layer)
            chunk_length = 1
            while layer_idx < num_layers:
                hidden_states = mpu.checkpoint(
                    custom(layer_idx, layer_idx + chunk_length), hidden_states,
                    enc_hidden_states, enc_attn_mask * 1)
                # NOTE(review): presumably guards against mpu.checkpoint
                # replacing/releasing the data of its tensor inputs; confirm
                # against the megatron_util checkpoint implementation.
                enc_hidden_states.data = pre_enc_hidden
                layer_idx += chunk_length
        else:
            for i, layer_module in enumerate(self.layer):
                hidden_states = layer_module(
                    hidden_states,
                    enc_hidden_states,
                    enc_attn_mask,
                    dec_attn_mask,
                    is_infer=is_infer)

        # Final LayerNorm, optionally in float32, cast back afterwards.
        previous_type = hidden_states.type()
        if self.fp32_layernorm:
            hidden_states = hidden_states.float()
        hidden_states = self.final_layernorm(hidden_states)
        if self.fp32_layernorm:
            hidden_states = hidden_states.type(previous_type)

        return [hidden_states]
class DecodeModel(PreTrainedBertModel):
    """Wraps ``BertDecoder`` with encoder-mask preparation and input embedding.

    The decoder owns no embedding table; the caller supplies one.
    ``PalmForPreTraining`` passes ``self.bert.embeddings``.
    """

    def __init__(self, config):
        super(DecodeModel, self).__init__(config)
        self.decoder = BertDecoder(config)
        self.apply(self.init_bert_weights)

    def forward(self,
                embeddings,
                sequence_output,
                decode_input_ids,
                position_ids=None,
                enc_attn_mask=None,
                dec_attn_mask=None,
                checkpoint_activations=False,
                is_infer=False):
        """Decode ``decode_input_ids`` while attending to ``sequence_output``.

        Args:
            embeddings: module mapping ``decode_input_ids`` to embeddings.
            sequence_output: encoder hidden states. Dims 0 and 1 are assumed
                to be (batch, encoder sequence) — TODO confirm for external
                callers.
            decode_input_ids: decoder token ids.
            position_ids: NOTE(review) currently unused; ``embeddings`` is
                called without it.
            enc_attn_mask: 1 = attend / 0 = masked over encoder positions.
                Defaults to attending to every position.
            dec_attn_mask: forwarded unchanged to the decoder.
            checkpoint_activations: NOTE(review) currently ignored; the
                decoder always runs with checkpointing off.
            is_infer: forwarded to the decoder layers.

        Returns:
            The decoder's (final-LayerNorm-ed) hidden states.
        """
        if enc_attn_mask is None:
            # The advertised None default would otherwise fail on .unsqueeze.
            enc_attn_mask = sequence_output.new_ones(sequence_output.shape[:2])
        extended_attention_mask = enc_attn_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.decoder.parameters()).dtype)  # fp16 compatibility
        # 0.0 where attended, -10000.0 where masked.
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        embedding_output = embeddings(decode_input_ids)
        sequence_output = self.decoder(
            embedding_output,
            sequence_output,
            extended_attention_mask,
            dec_attn_mask,
            checkpoint_activations=False,
            is_infer=is_infer)
        return sequence_output[-1]
class PalmForPreTraining(PreTrainedBertModel):
    """PALM encoder-decoder: BERT encoder with pre-training heads, plus a decoder.

    The decoder reuses the encoder's embedding module. Its output is projected
    onto the vocabulary with the encoder's word-embedding weight (no bias).

    ``forward`` returns:
        * ``(prediction_scores, logits_parallel)`` when ``parallel_output``;
        * ``(prediction_scores, gathered_logits, sequence_output)`` when
          ``is_infer``;
        * ``(prediction_scores, gathered_logits)`` otherwise.

    ``prediction_scores`` is ``None`` when ``sequence_output`` is supplied
    (the encoder is skipped). ``lm_labels`` is accepted but unused.
    """

    def __init__(self, config):
        super(PalmForPreTraining, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(
            config, self.bert.embeddings.word_embeddings.weight)
        self.decoder = DecodeModel(config)
        self.apply(self.init_bert_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                decode_input_ids=None,
                position_ids=None,
                decode_attention_mask=None,
                lm_labels=None,
                checkpoint_activations=False,
                is_infer=False,
                sequence_output=None,
                parallel_output=True):
        if sequence_output is not None:
            # Reuse precomputed encoder states, cast to the decoder's dtype.
            prediction_scores = None
            decoder_dtype = next(self.decoder.parameters()).dtype
            sequence_output = sequence_output.to(dtype=decoder_dtype)
        else:
            sequence_output, pooled_output = self.bert(
                input_ids,
                token_type_ids,
                attention_mask,
                output_all_encoded_layers=False,
                checkpoint_activations=checkpoint_activations)
            prediction_scores, _ = self.cls(sequence_output, pooled_output)

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        decode_output = self.decoder(
            self.bert.embeddings,
            sequence_output,
            decode_input_ids,
            position_ids,
            attention_mask,
            decode_attention_mask,
            checkpoint_activations=checkpoint_activations,
            is_infer=is_infer)

        logits_parallel = F.linear(
            mpu.copy_to_model_parallel_region(decode_output),
            self.bert.embeddings.word_embeddings.weight)
        if parallel_output:
            return prediction_scores, logits_parallel

        gathered_logits = mpu.gather_from_model_parallel_region(
            logits_parallel)
        if is_infer:
            return prediction_scores, gathered_logits, sequence_output
        return prediction_scores, gathered_logits
class PlugModel(torch.nn.Module):
    """
    The PLUG encoder-decoder model: a thin wrapper around `PalmForPreTraining`.

    `forward` returns the encoder's pre-training prediction scores together with the decoder's vocabulary
    logits (see `forward` for the exact return shape). `state_dict` / `load_state_dict` operate directly on
    the wrapped model, so parameter keys carry no `model.` prefix.

    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config ([`PlugNLGConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~DistributedPlug.initialize_model`] method to load the model weights.

    Examples:
        >>> # The PLUG model has 27B parameters and usually needs to run on multiple GPUs. The example given
        >>> # here only initializes a slice of the model on a single GPU.
        >>> # Check out the [`~DistributedPipeline.__init__`] method to initialize the entire PLUG model.
        >>> from modelscope.models.nlp.plug import PlugNLGConfig, PlugModel
        >>> # Initializing a Plug configuration
        >>> configuration = PlugNLGConfig()
        >>> # Initializing a model from the configuration
        >>> model = PlugModel(configuration)
    """

    def __init__(self, config):
        super(PlugModel, self).__init__()
        self.config = config
        self.model = PalmForPreTraining(self.config)

    def forward(self,
                input_tokens,
                token_type_ids=None,
                attention_mask=None,
                target_tokens=None,
                position_ids=None,
                decode_attention_mask=None,
                checkpoint_activations=False,
                is_infer=False,
                sequence_output=None,
                parallel_output=True):
        """
        Run the wrapped `PalmForPreTraining` model.

        Parameters:
            input_tokens (`torch.LongTensor` of shape `(batch_size, input_tokens_length)`):
                `input_tokens_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
                Indices can be obtained using transformers [`BertTokenizer`]. See
                [`TextGenerationPreprocessor.__call__`] for details.
            token_type_ids (`torch.LongTensor` of shape `(batch_size, input_tokens_length)`, *optional*, defaults to
                None):
                Segment token indices to indicate first and second portions of the inputs. Indices are selected in
                `[0, 1]`:
                - 0 corresponds to a *sentence A* token,
                - 1 corresponds to a *sentence B* token.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                If None, the wrapped model gives the decoder an all-ones mask shaped like `input_tokens`.
            target_tokens (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None):
                Token ids fed to the decoder (passed to the wrapped model as `decode_input_ids`). No loss is
                computed here; any label handling is the caller's responsibility.
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None):
                Position indices forwarded to the decoder. Selected in the range
                `[0, config.max_position_embeddings - 1]`.
            decode_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults
                to None):
                Mask to avoid performing attention on padding token indices of target tokens. Mask values selected
                in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            checkpoint_activations (`boolean`, *optional*, defaults to `False`):
                Whether gradient checkpointing is activated for this model or not.
            is_infer (`boolean`, *optional*, defaults to `False`):
                Whether or not to perform single-step inference.
            sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*,
                defaults to None):
                Cached encoder output (last hidden state). When given, the encoder is skipped and this tensor,
                cast to the decoder's dtype, is fed to the decoder; it can be obtained from a previous call with
                `is_infer=True` and `parallel_output=False` and reused for later decoding steps.
            parallel_output (`boolean`, *optional*, defaults to `True`):
                If True, return the model-parallel logits slice as-is; otherwise gather it before returning.

        Returns:
            A tuple `(prediction_scores, logits)`. `prediction_scores` is the encoder pre-training head output,
            or None when `sequence_output` was supplied. `logits` is the decoder output projected with the word
            embeddings. When `is_infer=True` and `parallel_output=False`, `sequence_output` is appended as a
            third element.
        """
        # Positional order matches PalmForPreTraining.forward: target_tokens
        # lands in its `decode_input_ids` slot.
        return self.model(
            input_tokens,
            token_type_ids,
            attention_mask,
            target_tokens,
            position_ids,
            decode_attention_mask,
            checkpoint_activations=checkpoint_activations,
            is_infer=is_infer,
            sequence_output=sequence_output,
            parallel_output=parallel_output)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the wrapped model's state dict (keys have no `model.` prefix).

        Presumably done so checkpoints are interchangeable with a bare
        `PalmForPreTraining` -- confirm against the checkpoint loader.
        """
        return self.model.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load `state_dict` into the wrapped model and return its result.

        Extra keyword arguments (e.g. `assign` on newer PyTorch versions) are
        forwarded unchanged to the wrapped model's `load_state_dict`.
        """
        return self.model.load_state_dict(state_dict, strict=strict, **kwargs)
|