# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import math
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.init import xavier_uniform_
from transformers import (BertConfig, BertModel, BertTokenizer, RobertaConfig,
                          RobertaModel, RobertaTokenizer)
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel

from modelscope.utils import logger as logging
from .configuration import PlugConfig

CONFIG_NAME = 'config.json'
WEIGHTS_NAME = 'pytorch_model.bin'


class MultiHeadedAttention(nn.Module):  # SelfAttention
    """
    Multi-Head Attention module from "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.

    Similar to standard `dot` attention but uses
    multiple attention distributions simultaneously
    to select relevant items.

    .. mermaid::

       graph BT
          A[key]
          B[value]
          C[query]
          O[output]
          subgraph Attn
            D[Attn 1]
            E[Attn 2]
            F[Attn N]
          end
          A --> D
          C --> D
          A --> E
          C --> E
          A --> F
          C --> F
          D --> O
          E --> O
          F --> O
          B --> O

    Also includes several additional tricks: key/value caching for
    incremental decoding (``layer_cache``), optional re-weighting of the
    last head (``predefined_graph_1``) and an optional output projection
    (``use_final_linear``).

    Args:
       head_count (int): number of parallel heads
       model_dim (int): the dimension of keys/values/queries,
           must be divisible by head_count
       dropout (float): dropout parameter
       use_final_linear (bool): project the merged heads back to
           ``model_dim``; when False the per-head context is returned as-is
    """

    def __init__(self,
                 head_count,
                 model_dim,
                 dropout=0.1,
                 use_final_linear=True):
        assert model_dim % head_count == 0
        super().__init__()
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim
        self.head_count = head_count

        self.linear_keys = nn.Linear(model_dim,
                                     head_count * self.dim_per_head)
        self.linear_values = nn.Linear(model_dim,
                                       head_count * self.dim_per_head)
        self.linear_query = nn.Linear(model_dim,
                                      head_count * self.dim_per_head)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.use_final_linear = use_final_linear
        if self.use_final_linear:
            self.final_linear = nn.Linear(model_dim, model_dim)

    def forward(self,
                key,
                value,
                query,
                mask=None,
                layer_cache=None,
                type=None,
                predefined_graph_1=None,
                return_attn=False):
        """
        Compute the context vector and the attention vectors.

        Args:
           key (`FloatTensor`): set of `key_len`
               key vectors `[batch, key_len, dim]`
           value (`FloatTensor`): set of `key_len`
               value vectors `[batch, key_len, dim]`
           query (`FloatTensor`): set of `query_len`
               query vectors  `[batch, query_len, dim]`
           mask: boolean mask `[batch, query_len, key_len]`; positions
               that are set are excluded from attention (score -inf)
           layer_cache (dict, optional): decoding cache. With
               ``type='self'`` keys/values are projected from ``query``
               (the ``key``/``value`` arguments are ignored) and appended to
               ``self_keys``/``self_values``. With ``type='context'`` the
               projected memory is computed once and then reused from
               ``memory_keys``/``memory_values``.
           type (str, optional): ``'self'`` or ``'context'``; only used
               together with ``layer_cache``.
           predefined_graph_1 (`FloatTensor`, optional): multiplicative
               weights for the attention of the last head, which is then
               renormalised.
           return_attn (bool): also return the attention probabilities.

        Returns:
           The context `[batch, query_len, dim]` (or
           `[batch, head_count, query_len, dim_per_head]` when
           ``use_final_linear`` is False) and, when ``return_attn`` is set,
           the attention probabilities `[batch, head_count, query_len,
           key_len]` as a second element.
        """

        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count

        def shape(x):
            """[batch, len, dim] -> [batch, heads, len, dim_per_head]."""
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """[batch, heads, len, dim_per_head] -> [batch, len, dim]."""
            return x.transpose(1, 2).contiguous() \
                .view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if type == 'self':
                query, key, value = self.linear_query(query), \
                    self.linear_keys(query), self.linear_values(query)

                key = shape(key)
                value = shape(value)

                device = key.device
                if layer_cache['self_keys'] is not None:
                    key = torch.cat(
                        (layer_cache['self_keys'].to(device), key), dim=2)
                if layer_cache['self_values'] is not None:
                    value = torch.cat(
                        (layer_cache['self_values'].to(device), value),
                        dim=2)
                layer_cache['self_keys'] = key
                layer_cache['self_values'] = value
            elif type == 'context':
                query = self.linear_query(query)
                if layer_cache['memory_keys'] is None:
                    key, value = self.linear_keys(key), self.linear_values(
                        value)
                    key = shape(key)
                    value = shape(value)
                else:
                    key, value = layer_cache['memory_keys'], layer_cache[
                        'memory_values']
                layer_cache['memory_keys'] = key
                layer_cache['memory_values'] = value
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

        query = shape(query)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        scores = torch.matmul(query, key.transpose(2, 3))

        if mask is not None:
            mask = mask.unsqueeze(1).expand_as(scores)
            scores = scores.masked_fill(mask, float('-inf'))

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores)

        if predefined_graph_1 is not None:
            # Re-weight only the last head and renormalise it to sum to 1.
            attn_masked = attn[:, -1] * predefined_graph_1
            attn_masked = attn_masked / (
                torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9)

            attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1)

        drop_attn = self.dropout(attn)
        if self.use_final_linear:
            context = unshape(torch.matmul(drop_attn, value))
            output = self.final_linear(context)
            if return_attn:
                return output, attn
            else:
                return output
        else:
            context = torch.matmul(drop_attn, value)
            if return_attn:
                return context, attn
            else:
                return context


class PositionwiseFeedForward(nn.Module):  # Output
    """ A two-layer Feed-Forward-Network with pre-layer-norm and a
        residual connection.

        Args:
            d_model (int): the size of input for the first-layer of the FFN.
            d_ff (int): the hidden layer size of the FFN.
            dropout (float): dropout probability in :math:`[0, 1)`.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.w_1 = nn.Linear(d_model, d_ff)
        self.actv = ACT2FN['gelu_new']
        self.dropout_1 = nn.Dropout(dropout)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x))))
        output = self.dropout_2(self.w_2(inter))
        return output + x


class TransformerDecoderLayer(nn.Module):  # Layer
    """
    Args:
      d_model (int): the dimension of keys/values/queries in
                       MultiHeadedAttention, also the input size of
                       the first-layer of the PositionwiseFeedForward.
      heads (int): the number of heads for MultiHeadedAttention.
      d_ff (int): the hidden size of the PositionwiseFeedForward.
      dropout (float): dropout probability(0-1.0).
    """
    MAX_SIZE = 5000

    def __init__(self, d_model, heads, d_ff, dropout):
        super().__init__()

        self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)

        self.context_attn = MultiHeadedAttention(
            heads, d_model, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        mask = self._get_attn_subsequent_mask(self.MAX_SIZE)
        # Register self.mask as a buffer in TransformerDecoderLayer, so
        # it gets TransformerDecoderLayer's cuda behavior automatically.
        self.register_buffer('mask', mask)

    def forward(self,
                inputs,
                memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                previous_input=None,
                layer_cache=None,
                step=None):
        """
        Args:
            inputs (`FloatTensor`): `[batch_size x tgt_len x model_dim]`
            memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]`
            src_pad_mask (`BoolTensor`): `[batch_size x tgt_len x src_len]`
            tgt_pad_mask (`BoolTensor`): `[batch_size x tgt_len x tgt_len]`
            previous_input (`FloatTensor`, optional): normalised inputs of
                earlier steps; when given, they are prepended to the
                self-attention keys/values and no causal mask is applied.
            layer_cache (dict, optional): per-layer decoding cache.
            step (int, optional): unused; kept for interface compatibility.

        Returns:
            (`FloatTensor`, `FloatTensor`, `FloatTensor`):

            * output `[batch_size x tgt_len x model_dim]`
            * attn `[batch_size x heads x tgt_len x src_len]`
            * all_input `[batch_size x current_step x model_dim]`

        """
        # Padding mask OR causal (subsequent-position) mask.
        dec_mask = torch.gt(
            tgt_pad_mask.type(torch.uint8)
            + self.mask[:, :tgt_pad_mask.size(1), :tgt_pad_mask.size(1)].type(
                torch.uint8), 0)
        input_norm = self.layer_norm_1(inputs)
        all_input = input_norm
        if previous_input is not None:
            all_input = torch.cat((previous_input, input_norm), dim=1)
            dec_mask = None

        query = self.self_attn(
            all_input,
            all_input,
            input_norm,
            mask=dec_mask,
            layer_cache=layer_cache,
            type='self')

        query = self.drop(query) + inputs

        query_norm = self.layer_norm_2(query)
        mid, attn = self.context_attn(
            memory_bank,
            memory_bank,
            query_norm,
            mask=src_pad_mask,
            layer_cache=layer_cache,
            type='context',
            return_attn=True)
        output = self.feed_forward(self.drop(mid) + query)

        return output, attn, all_input

    def _get_attn_subsequent_mask(self, size):
        """
        Get an attention mask to avoid using the subsequent info.

        Args:
            size: int

        Returns:
            (`ByteTensor`):

            * subsequent_mask `[1 x size x size]`, 1 strictly above the
              diagonal (future positions) and 0 elsewhere
        """
        attn_shape = (1, size, size)
        subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
        subsequent_mask = torch.from_numpy(subsequent_mask)
        return subsequent_mask


class PositionalEncoding(nn.Module):
    """Scale embeddings by sqrt(dim) and add fixed sinusoidal positions."""

    def __init__(self, dropout, dim, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float)
                              * -(math.log(10000.0) / dim)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        self.dropout = nn.Dropout(dropout)
        self.dim = dim

    def forward(self, emb, step=None):
        """Add positions to ``emb`` (`[batch, len, dim]`).

        With a truthy ``step`` the single position ``step`` is broadcast over
        the sequence (incremental decoding); otherwise positions
        ``0..len-1`` are used.
        """
        emb = emb * math.sqrt(self.dim)
        if step:
            emb = emb + self.pe[:, step][:, None, :]
        else:
            emb = emb + self.pe[:, :emb.size(1)]
        emb = self.dropout(emb)
        return emb

    def get_emb(self, emb):
        return self.pe[:, :emb.size(1)]


class TransformerDecoderState:
    """Decoder state: source ids plus either a per-layer cache (incremental
    decoding) or the previous inputs (full-sequence decoding)."""

    def __init__(self, src: Tensor, cache_num_layers: int = -1):
        self.src: Tensor = src
        self.previous_input: Optional[Tensor] = None
        self.previous_layer_inputs: Optional[Tensor] = None
        self.cache: Optional[Dict[str, Any]] = None
        if cache_num_layers != -1:
            self._init_cache(cache_num_layers)

    def update_state(self, new_input, previous_layer_inputs):
        self.previous_input = new_input
        self.previous_layer_inputs = previous_layer_inputs
        self.cache = None

    def _init_cache(self, num_layers):
        self.cache = {}
        for layer in range(num_layers):
            layer_cache = {'memory_keys': None, 'memory_values': None}
            layer_cache['self_keys'] = None
            layer_cache['self_values'] = None
            self.cache['layer_{}'.format(layer)] = layer_cache

    def map_batch_fn(self, fn):
        """Apply ``fn(tensor, batch_dim)`` to ``src`` and every cached
        tensor (recursively through nested dicts), skipping ``None``."""

        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        self.src = fn(self.src, 0)
        if self.cache is not None:
            _recursive_map(self.cache)


class TransformerDecoder(nn.Module):  # Decoder
    """
    The Transformer decoder from "Attention is All You Need".

    .. mermaid::

       graph BT
          A[input]
          B[multi-head self-attn]
          BB[multi-head src-attn]
          C[feed forward]
          O[output]
          A --> B
          B --> BB
          BB --> C
          C --> O

    Args:
       num_layers (int): number of decoder layers.
       d_model (int): size of the model
       heads (int): number of heads
       d_ff (int): size of the inner FF layer
       dropout (float): dropout parameters
       embeddings (:obj:`nn.Embedding`): target token embeddings; sinusoidal
          positions are added by this module
    """
    decoder_type = 'transformer'

    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings):
        super().__init__()

        # Basic attributes.
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(dropout,
                                          self.embeddings.embedding_dim)

        # Build TransformerDecoder.
        self.transformer_layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.state = None

    def forward(self,
                state: TransformerDecoderState,
                tgt: Tensor,
                memory_bank: Tensor,
                step: int = None,
                memory_masks: Tensor = None):
        """Decode ``tgt`` (`[batch, tgt_len]` token ids) against
        ``memory_bank`` (`[batch, src_len, dim]`).

        Returns the final hidden states `[batch, tgt_len, dim]`, the list of
        per-layer context attentions and the (updated) ``state``.
        """
        src_words = state.src
        tgt_words = tgt
        src_batch, src_len = src_words.size()
        tgt_batch, tgt_len = tgt_words.size()

        # Run the forward pass of the TransformerDecoder.
        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # batch x len x embedding_dim

        output = self.pos_emb(emb, step)

        src_memory_bank = memory_bank
        padding_idx = self.embeddings.padding_idx
        tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \
            .expand(tgt_batch, tgt_len, tgt_len)

        if memory_masks is not None:
            src_len = memory_masks.size(-1)
            src_pad_mask = memory_masks.expand(src_batch, tgt_len, src_len)

        else:
            src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \
                .expand(src_batch, tgt_len, src_len)

        if state.cache is None:
            saved_inputs = []
        attns = []
        for i in range(self.num_layers):
            prev_layer_input = None
            if state.cache is None:
                if state.previous_input is not None:
                    prev_layer_input = state.previous_layer_inputs[i]

            output, attn, all_input \
                = self.transformer_layers[i](
                    output, src_memory_bank,
                    src_pad_mask, tgt_pad_mask,
                    previous_input=prev_layer_input,
                    layer_cache=state.cache['layer_{}'.format(i)]
                    if state.cache is not None else None,
                    step=step)
            if state.cache is None:
                saved_inputs.append(all_input)
            attns.append(attn)

        if state.cache is None:
            saved_inputs = torch.stack(saved_inputs)

        output = self.layer_norm(output)

        # Process the result and update the attentions.
        if state.cache is None:
            state.update_state(tgt, saved_inputs)

        return output, attns, state


class PlugPointerGenerator(nn.Module):
    """Project hidden states to vocabulary log-probabilities."""

    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, vocab_size)
        self.gen_func = nn.LogSoftmax(-1)

    def forward(self, x):
        x = self.dense(x)
        x = self.gen_func(x)
        return x


class PlugPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = PlugConfig
    base_model_prefix = 'plug'

    @classmethod
    def from_pretrained(
            cls,
            pretrained_model_name_or_path: Optional[Union[str,
                                                          os.PathLike]]):
        """Build the model from a local directory.

        Reads ``config.json`` (falling back to a default ``PlugConfig`` when
        it is absent), resolves ``config.encoder_pth`` relative to the
        directory, and loads ``pytorch_model.bin`` when it exists.
        """
        config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        config = PlugConfig.from_json_file(config_file) if os.path.isfile(
            config_file) else PlugConfig()
        config.encoder_pth = os.path.join(pretrained_model_name_or_path,
                                          config.encoder_pth)
        checkpoint_file = os.path.join(pretrained_model_name_or_path,
                                       WEIGHTS_NAME)
        # Load onto CPU so checkpoints saved from GPU also load on CPU-only
        # hosts; load_state_dict later copies the values into the model.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(
            checkpoint_file, map_location='cpu') if os.path.isfile(
                checkpoint_file) else None
        return cls(config, checkpoint)


class PlugModel(PlugPreTrainedModel):  # Model
    """BERT/RoBERTa encoder + Transformer decoder + tied generator."""

    def __init__(self, config, checkpoint=None):
        super().__init__(config)
        self.config = config
        if config.encoder == 'bert' or config.encoder == 'zh_bert':
            self.bert = BertModel(
                BertConfig.from_pretrained(config.encoder_pth))
        elif config.encoder == 'roberta':
            self.bert = RobertaModel(
                RobertaConfig.from_pretrained(config.encoder_pth))
        else:
            raise ValueError('Unsupported encoder type: {!r}'.format(
                config.encoder))

        if config.max_pos > 512:
            # Extend the position table: keep the first 512 pretrained rows
            # and fill the remainder with copies of the last pretrained row.
            # NOTE(review): the HF embeddings' ``position_ids`` buffer is not
            # resized here; confirm >512-token inputs work with the installed
            # transformers version.
            old_pos = self.bert.embeddings.position_embeddings.weight.data
            my_pos_embeddings = nn.Embedding(config.max_pos,
                                             self.bert.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = old_pos
            my_pos_embeddings.weight.data[512:] = old_pos[-1][None, :].repeat(
                config.max_pos - 512, 1)
            self.bert.embeddings.position_embeddings = my_pos_embeddings
        self.vocab_size = self.bert.config.vocab_size
        tgt_embeddings = nn.Embedding(
            self.vocab_size,
            self.bert.config.hidden_size,
            padding_idx=1 if config.encoder == 'roberta' else 0)

        if config.share_emb:
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(
            config.dec_layers,
            config.dec_hidden_size,
            heads=config.dec_heads,
            d_ff=config.dec_ff_size,
            dropout=config.dec_dropout,
            embeddings=tgt_embeddings)
        self.generator = PlugPointerGenerator(config.dec_hidden_size,
                                              self.vocab_size)
        # Tie the output projection to the decoder input embeddings.
        self.generator.dense.weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            state_dict = checkpoint['model']
            # Strip wrapper prefixes ('module.' from (Distributed)DataParallel,
            # then 'plug.' from PlugForConditionalGeneration) so the keys
            # match this module; otherwise strict=False silently skips them.
            for key in list(state_dict.keys()):
                new_key = key
                for prefix in ('module.', 'plug.'):
                    if new_key.startswith(prefix):
                        new_key = new_key[len(prefix):]
                if new_key != key:
                    state_dict[new_key] = state_dict.pop(key)
            msg = self.load_state_dict(state_dict, strict=False)
            logging.get_logger().info('Checkpoint load result: %s', msg)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if config.use_bert_emb:
                tgt_embeddings = nn.Embedding(
                    self.vocab_size,
                    self.bert.config.hidden_size,
                    padding_idx=1 if config.encoder == 'roberta' else 0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator.dense.weight = self.decoder.embeddings.weight

    def forward(self, src, tgt, mask_src, token_type_ids):
        """Teacher-forced pass; returns (decoder states, last-layer context
        attention, encoder states)."""
        top_vec, _ = self.bert(
            src, mask_src, token_type_ids=token_type_ids, return_dict=False)
        state = TransformerDecoderState(src)
        decoder_outputs, attns, _ = self.decoder(state, tgt[:, :-1], top_vec)
        return decoder_outputs, attns[-1], top_vec


class LabelSmoothingLoss(nn.Module):
    """
    With label smoothing,
    KL-divergence between q_{smoothed ground truth prob.}(w)
    and p_{prob. computed by model}(w) is minimized.
    """

    def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
        assert 0.0 < label_smoothing <= 1.0
        super(LabelSmoothingLoss, self).__init__()
        self.padding_idx = ignore_index

        smoothing_value = label_smoothing / (tgt_vocab_size - 2)
        one_hot = torch.full((tgt_vocab_size, ), smoothing_value)
        one_hot[self.padding_idx] = 0
        self.register_buffer('one_hot', one_hot.unsqueeze(0))
        self.confidence = 1.0 - label_smoothing

    def forward(self, output, target):
        """
        output (FloatTensor): batch_size x n_classes (log-probabilities)
        target (LongTensor): batch_size
        """
        model_prob = self.one_hot.repeat(target.size(0), 1)
        model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
        model_prob.masked_fill_((target == self.padding_idx).unsqueeze(1), 0)

        return F.kl_div(output, model_prob, reduction='sum')


class NMTLossCompute(nn.Module):
    """
    Standard NMT Loss Computation.
    """

    def __init__(self, generator, symbols, vocab_size, label_smoothing=0.0):
        super().__init__()
        self.generator = generator
        self.padding_idx = symbols['PAD']
        if label_smoothing > 0:
            self.criterion = LabelSmoothingLoss(
                label_smoothing, vocab_size, ignore_index=self.padding_idx)
        else:
            self.criterion = nn.NLLLoss(
                ignore_index=self.padding_idx, reduction='sum')

    def _bottle(self, _v):
        return _v.view(-1, _v.size(2))

    def _unbottle(self, _v, batch_size):
        return _v.view(-1, batch_size, _v.size(1))

    def forward(self, tgt, output):
        """Return (loss per non-pad target token, generator log-probs
        `[batch, tgt_len - 1, vocab]`)."""
        target = tgt[:, 1:]
        batch_size, decoder_length = target.size(0), target.size(1)
        normalization = target.ne(self.padding_idx).sum()
        bottled_output = self._bottle(output)
        scores = self.generator(bottled_output)
        gtruth = target.contiguous().view(-1)
        loss = self.criterion(scores, gtruth)
        # Guard against 0/0 = NaN when every target token is padding.
        loss = loss.div(max(float(normalization), 1.0))
        return loss, scores.view(batch_size, decoder_length, -1)


class PlugForConditionalGeneration(PlugPreTrainedModel):

    @dataclass
    class Batch:
        batch_size: int
        src: torch.Tensor
        tgt: torch.Tensor
        mask_src: torch.Tensor
        token_type_ids: torch.Tensor
        query_id: List[None] = None
        src_str: List[List[str]] = None
        tgt_str: List[str] = None

    def __init__(self, config, checkpoint=None, dataset: str = 'default'):
        super().__init__(config)
        self.logger = logging.get_logger()
        self.config = config
        if config.encoder == 'roberta':
            tokenizer = RobertaTokenizer.from_pretrained(
                config.encoder_pth, do_lower_case=False)
            symbols = {
                'BOS': tokenizer.cls_token_id,
                'EOS': tokenizer.sep_token_id,
                'PAD': tokenizer.pad_token_id,
                'EOQ': tokenizer.unk_token_id
            }
        elif config.encoder == 'bert' or config.encoder == 'zh_bert':
            tokenizer = BertTokenizer.from_pretrained(
                config.encoder_pth, do_lower_case=True)
            symbols = {
                'BOS': tokenizer.vocab['[CLS]'],
                'EOS': tokenizer.vocab['[SEP]'],
                'PAD': tokenizer.vocab['[PAD]'],
                'EOQ': tokenizer.vocab['[unused2]']
            }
        else:
            raise ValueError('Unsupported encoder type: {!r}'.format(
                config.encoder))
        self.tokenizer = tokenizer
        self.symbols = symbols
        self.plug = PlugModel(config, checkpoint)
        self.loss = NMTLossCompute(self.plug.generator, symbols,
                                   self.plug.vocab_size,
                                   config.label_smoothing)
        # for generation
        self.config.dataset = dataset
        self.start_token = self.symbols['BOS']
        self.end_token = self.symbols['EOS']

    def forward(self, src, tgt, mask_src=None, token_type_ids=None):
        if mask_src is None:
            mask_src = src.ne(self.symbols['PAD']).long()
        output = self.plug(src, tgt, mask_src, token_type_ids)[0]
        loss = self.loss(tgt, output)
        return loss

    def translate_batch(self,
                        batch: 'Batch',
                        fast: bool = False,
                        *args,
                        **kwargs):
        """
        Translate a batch of sentences.

        Mostly a wrapper around :obj:`Beam`.

        Args:
           batch (:obj:`Batch`): a batch from a dataset object
           fast (bool): unused; beam search always uses the fast code path

        Todo:
           Shouldn't need the original dataset.
        """
        self.plug.eval()
        with torch.no_grad():
            return self._fast_translate_batch(batch, *args, **kwargs)

    def _tile(self, x, count, dim=0):
        """Repeat each slice of ``x`` along ``dim`` ``count`` times in place
        (``[a, b] -> [a, a, b, b]``)."""
        perm = list(range(len(x.size())))
        if dim != 0:
            perm[0], perm[dim] = perm[dim], perm[0]
            x = x.permute(perm).contiguous()
        out_size = list(x.size())
        out_size[0] *= count
        batch = x.size(0)
        x = x.view(batch, -1) \
            .transpose(0, 1) \
            .repeat(count, 1) \
            .transpose(0, 1) \
            .contiguous() \
            .view(*out_size)
        if dim != 0:
            x = x.permute(perm).contiguous()
        return x

    def _top_k_top_p_filtering(self,
                               logits,
                               top_k=10,
                               top_p=1.0,
                               filter_value=-float('Inf'),
                               min_tokens_to_keep=1):
        """Mask ``logits`` (in place) outside the top-k / nucleus set."""
        if top_k > 0:
            top_k = min(max(top_k, min_tokens_to_keep),
                        logits.size(-1))  # Safety check
            # Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = logits < torch.topk(logits,
                                                    top_k)[0][..., -1, None]
            logits[indices_to_remove] = filter_value

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(
                F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
            sorted_indices_to_remove = cumulative_probs > top_p
            if min_tokens_to_keep > 1:
                # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
                sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                ..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # scatter sorted tensors to original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(
                1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = filter_value
        return logits

    def _ban_bad_words(self, log_probs, alive_seq, bad_words_prefix_dict,
                       bad_words_prefix_len):
        """Set -1e20 (in place) on every token that would complete a bad word.

        For each hypothesis and each prefix length ``n``, the last ``n``
        generated tokens are looked up in ``bad_words_prefix_dict`` and the
        listed final tokens are banned.
        """
        num_hypos = alive_seq.size(0)
        assert log_probs.size(0) == num_hypos
        cur_len = alive_seq.size(1)
        for i, seq in enumerate(alive_seq.tolist()):
            banned = set()
            for pre_len in bad_words_prefix_len:
                if pre_len > cur_len:
                    continue  # not enough tokens generated for this prefix
                pre_key = tuple(seq[cur_len - pre_len:])
                banned.update(bad_words_prefix_dict.get(pre_key, []))
            for banned_token in banned:
                log_probs[i, banned_token] = -1e20

    def _apply_repetition_penalty(self, log_probs, alive_seq,
                                  repetition_penalty, no_repeat_ngram_size,
                                  cur_len):
        """CTRL-style repetition penalty (https://arxiv.org/abs/1909.05858).

        The penalised tokens are those returned by ``calc_banned_tokens``
        (tokens that would repeat an already generated
        ``no_repeat_ngram_size``-gram). Negative scores are multiplied by
        ``repetition_penalty`` and others divided by it, in place.
        """
        prev_output_tokens = self.calc_banned_tokens(
            alive_seq, alive_seq.size(0), no_repeat_ngram_size, cur_len)
        for i in range(log_probs.size(0)):
            for previous_token in set(prev_output_tokens[i]):
                if log_probs[i, previous_token] < 0:
                    log_probs[i, previous_token] *= repetition_penalty
                else:
                    log_probs[i, previous_token] /= repetition_penalty

    def _has_repeated_trigram(self, token_ids):
        """Return True if the last word trigram of ``token_ids`` already
        occurred earlier in the detokenized sequence."""
        if self.config.encoder == 'roberta':
            words = self.tokenizer.decode(token_ids).strip().split()
        else:
            words = [self.tokenizer.ids_to_tokens[w] for w in token_ids]
            # Re-join WordPiece continuation pieces into whole words.
            words = ' '.join(words).replace(' ##', '').split()
        if len(words) <= 3:
            return False
        trigrams = [
            tuple(words[k - 1:k + 2]) for k in range(1, len(words) - 1)
        ]
        return trigrams[-1] in trigrams[:-1]

    def _fast_translate_batch(self,
                              batch: 'Batch',
                              max_length: int = 80,
                              min_length: int = 10,
                              bad_words_ids=None,
                              early_stopping=True,
                              num_beams=3,
                              length_penalty=1.2,
                              repetition_penalty=1.2,
                              no_repeat_ngram_size=4,
                              do_sample=False,
                              temperature=1.0,
                              top_k=0,
                              top_p=1.0,
                              *args,
                              **kwargs):
        """Beam search (or per-beam sampling when ``do_sample``) decoding.

        Args:
            batch: input batch (``src``, ``mask_src``, ``token_type_ids``).
            max_length: maximum number of decoding steps.
            min_length: EOS is forbidden before this many steps.
            bad_words_ids: token-id sequences that must not be generated.
            early_stopping: finish an example once ``num_beams`` hypotheses
                have ended, even if the top beam has not.
            num_beams: beam width.
            length_penalty: scores are divided by ``step ** length_penalty``.
            repetition_penalty: CTRL-style penalty (> 1 enables it).
            no_repeat_ngram_size: n-gram size used to find penalised tokens.
            do_sample, temperature, top_k, top_p: sampling controls.

        Returns:
            dict with ``predictions`` and ``scores`` (one list per example,
            predictions exclude the start token), ``gold_score`` and
            ``batch``.
        """
        # TODO: faster code path for beam_size == 1.
        batch_size = batch.batch_size
        src = batch.src
        mask_src = batch.mask_src
        token_type_ids = batch.token_type_ids

        src_features, _ = self.plug.bert(
            src, mask_src, token_type_ids=token_type_ids, return_dict=False)
        state = TransformerDecoderState(src, self.plug.decoder.num_layers)
        device = src_features.device

        # Tile states and memory beam_size times.
        state.map_batch_fn(
            lambda state, dim: self._tile(state, num_beams, dim=dim))
        src_features = self._tile(src_features, num_beams, dim=0)
        batch_offset = torch.arange(
            batch_size, dtype=torch.long, device=device)
        beam_offset = torch.arange(
            0,
            batch_size * num_beams,
            step=num_beams,
            dtype=torch.long,
            device=device)
        alive_seq = torch.full([batch_size * num_beams, 1],
                               self.start_token,
                               dtype=torch.long,
                               device=device)

        # Index bad words by prefix: prefix tuple -> list of banned last tokens.
        bad_words_prefix_dict = {}
        bad_words_prefix_len = set()
        if bad_words_ids is not None:
            for bw_id in bad_words_ids:
                key = tuple(bw_id[:-1])
                bad_words_prefix_dict.setdefault(key, []).append(bw_id[-1])
                bad_words_prefix_len.add(len(key))

        # Give full probability to the first beam on the first step.
        topk_log_probs = (
            torch.tensor([0.0] + [float('-inf')] * (num_beams - 1),
                         device=device).repeat(batch_size))

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = {}
        results['predictions'] = [[] for _ in range(batch_size)]  # noqa: F812
        results['scores'] = [[] for _ in range(batch_size)]  # noqa: F812
        results['gold_score'] = [0] * batch_size
        results['batch'] = batch

        for step in range(max_length):
            # Decoder forward on the last generated token of every beam.
            decoder_input = alive_seq[:, -1:]
            dec_out, _, state = self.plug.decoder(
                state, decoder_input, src_features, step=step)

            # Generator forward.
            log_probs = self.plug.generator(
                dec_out.transpose(0, 1).squeeze(0))
            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, self.end_token] = -1e20

            if len(bad_words_prefix_dict) > 0:
                self._ban_bad_words(log_probs, alive_seq,
                                    bad_words_prefix_dict,
                                    bad_words_prefix_len)

            if repetition_penalty > 1.0:
                self._apply_repetition_penalty(log_probs, alive_seq,
                                               repetition_penalty,
                                               no_repeat_ngram_size, step + 1)

            # Multiply probs by the beam probability.
            curr_length_penalty = (step + 1)**length_penalty
            if do_sample:
                _scores = log_probs / temperature
                _scores = self._top_k_top_p_filtering(
                    _scores, top_k=top_k, top_p=top_p,
                    min_tokens_to_keep=1)  # (batch_size * num_beams, vocab)
                # Sample one next token for each beam.
                sampled_ids = torch.multinomial(
                    F.softmax(_scores, dim=-1),
                    num_samples=1)  # (batch_size * num_beams, 1)

                # Compute next scores
                _scores = F.log_softmax(_scores, dim=1)
                _scores += topk_log_probs.view(-1).unsqueeze(1)
                _scores = _scores / curr_length_penalty
                topk_scores = torch.gather(
                    _scores, -1, sampled_ids).view(-1, num_beams)
                # Encode ids like the flattened beam-search ids
                # (beam * vocab_size + token) so the bookkeeping below keeps
                # each sample attached to the beam it was drawn from.
                beam_base = torch.arange(
                    num_beams, device=device) * vocab_size
                topk_ids = sampled_ids.view(-1, num_beams) + beam_base
            else:
                log_probs += topk_log_probs.view(-1).unsqueeze(1)
                curr_scores = log_probs / curr_length_penalty

                # Trigram blocking must run before top-k so it can affect
                # which candidates survive.
                if self.config.block_trigram and alive_seq.size(1) > 3:
                    for i in range(alive_seq.size(0)):
                        if self._has_repeated_trigram(alive_seq[i].tolist()):
                            curr_scores[i] = -10e20

                curr_scores = curr_scores.reshape(-1, num_beams * vocab_size)
                topk_scores, topk_ids = curr_scores.topk(num_beams, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * curr_length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids // vocab_size
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = (
                topk_beam_index
                + beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat([
                alive_seq.index_select(0, select_indices),
                topk_ids.view(-1, 1)
            ], -1)

            is_finished = topk_ids.eq(self.end_token)
            if step + 1 == max_length:
                is_finished.fill_(True)
            # End condition is top beam is finished.
            end_condition = is_finished[:, 0].eq(1)
            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, num_beams,
                                             alive_seq.size(-1))
                for i in range(is_finished.size(0)):
                    b = int(batch_offset[i])
                    if end_condition[i]:
                        is_finished[i].fill_(True)
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append(
                            (topk_scores[i, j], predictions[i, j, 1:]))
                    # Early stopping: this example is done once it holds
                    # num_beams finished hypotheses.
                    if early_stopping and len(hypotheses[b]) >= num_beams:
                        end_condition[i] = True
                    # If the batch reached the end, save the n_best hypotheses.
                    if end_condition[i]:
                        best_hyp = sorted(
                            hypotheses[b], key=lambda x: x[0], reverse=True)
                        if self.config.dataset == 'qg_ranking_test' or (
                                self.config.dataset == 'paraphrase'
                                and not self.config.sample_topk):
                            for score, pred in best_hyp[:num_beams]:
                                results['scores'][b].append(score)
                                results['predictions'][b].append(pred)
                        else:
                            score, pred = best_hyp[0]
                            results['scores'][b].append(score)
                            results['predictions'][b].append(pred)
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished) \
                    .view(-1, alive_seq.size(-1))
            # Reorder states.
            select_indices = batch_index.view(-1)
            src_features = src_features.index_select(0, select_indices)
            state.map_batch_fn(
                lambda state, dim: state.index_select(dim, select_indices))

        return results

    def calc_banned_tokens(self, prev_input_ids, num_hypos,
                           no_repeat_ngram_size, cur_len):
        """Return, per hypothesis, the tokens that would complete an
        ``no_repeat_ngram_size``-gram already present in ``prev_input_ids``.
        """
        # Copied from fairseq for no_repeat_ngram in beam_search
        if cur_len + 1 < no_repeat_ngram_size:
            # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
            return [[] for _ in range(num_hypos)]
        generated_ngrams = [{} for _ in range(num_hypos)]
        for idx in range(num_hypos):
            gen_tokens = prev_input_ids[idx].cpu().numpy().tolist()
            generated_ngram = generated_ngrams[idx]
            for ngram in zip(
                    *[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
                prev_ngram_tuple = tuple(ngram[:-1])
                generated_ngram[prev_ngram_tuple] = generated_ngram.get(
                    prev_ngram_tuple, []) + [ngram[-1]]

        def _get_generated_ngrams(hypo_idx):
            # Before decoding the next token, prevent decoding of ngrams that have already appeared
            start_idx = cur_len + 1 - no_repeat_ngram_size
            ngram_idx = tuple(prev_input_ids[hypo_idx,
                                             start_idx:cur_len].cpu().numpy()
                              .tolist())
            return generated_ngrams[hypo_idx].get(ngram_idx, [])

        banned_tokens = [
            _get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)
        ]
        return banned_tokens

    def translate(self,
                  input_ids: torch.Tensor,
                  attention_mask: torch.Tensor = None,
                  token_type_ids=None,
                  *args,
                  **kwargs) -> Dict[str, torch.Tensor]:
        """Generate from ``input_ids``; extra arguments are forwarded to the
        decoding routine. Returns ``{'predictions': ...}``."""
        if attention_mask is None:
            attention_mask = input_ids.ne(self.symbols['PAD']).long()
        batch = self.Batch(
            batch_size=input_ids.size()[0],
            src=input_ids,
            tgt=None,
            token_type_ids=token_type_ids,
            mask_src=attention_mask)
        translation_batch = self.translate_batch(batch, *args, **kwargs)

        preds = translation_batch['predictions']
        return {'predictions': preds}