| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574 |
- # Part of the implementation is borrowed and modified from THUMT,
- # publicly available at https://github.com/THUNLP-MT/THUMT
- # Copyright 2017-2022 The Alibaba MT Team Authors. All rights reserved.
- import math
- from collections import namedtuple
- from typing import Dict
- import tensorflow as tf
- from modelscope.metainfo import Models
- from modelscope.models.base import Model, Tensor
- from modelscope.models.builder import MODELS
- from modelscope.utils.constant import Tasks
- __all__ = ['CsanmtForTranslation']
- @MODELS.register_module(Tasks.translation, module_name=Models.translation)
- class CsanmtForTranslation(Model):
def __init__(self, model_dir, *args, **kwargs):
    """Initialise the base model and keep the keyword arguments as config.

    Args:
        model_dir: passed through unchanged to the base ``Model``
            constructor.
        *args: extra positional arguments, passed to the base constructor.
        **kwargs: passed to the base constructor and additionally stored
            as ``self.params``; the graph-building methods of this class
            read their hyper-parameters from that dictionary.
    """
    super().__init__(model_dir, *args, **kwargs)
    config = kwargs
    self.params = config
    print(config)
def __call__(self,
             input: Dict[str, Tensor],
             label: Dict[str, Tensor] = None,
             prefix: Dict[str, Tensor] = None,
             prefix_hit: Dict[bool, Tensor] = None) -> Dict[str, Tensor]:
    """Build either the training graph or the beam-search inference graph.

    Args:
        input: the preprocessed input source sequence.
        label: the ground truth target data for model training. When it is
            given the training graph is built; when it is ``None`` the
            inference graph is built instead.
        prefix: the preprocessed input target prefix sequence for
            interactive translation (inference only).
        prefix_hit: the preprocessed target prefix subword vector for
            interactive translation (inference only).

    Returns:
        With ``label``: ``{'train_op': ..., 'loss': ...}``.
        Without ``label``: ``{'output_seqs': ..., 'output_scores': ...}``.
    """
    if label is not None:
        train_op, loss = self.transformer_model_train_fn(input, label)
        return {'train_op': train_op, 'loss': loss}

    search_features = {
        'input_wids': input,
        'prefix_wids': prefix,
        'prefix_hit': prefix_hit
    }
    # Inference variables are created/looked up under the 'NmtModel' scope.
    with tf.compat.v1.variable_scope('NmtModel'):
        output_seqs, output_scores = self.beam_search(
            search_features, self.params)
    return {'output_seqs': output_seqs, 'output_scores': output_scores}
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """Placeholder for the model forward pass.

    The body is intentionally empty, so calling it returns ``None``.

    Args:
        input (Dict[str, Tensor]): the dict of the model inputs for the
            forward method.

    Returns:
        Dict[str, Tensor]: declared output type of the forward pass (the
        current placeholder does not produce one).
    """
    pass
def encoding_graph(self, features, params):
    """Embed the source ids and run them through ``transformer_encoder``.

    Args:
        features: integer id tensor, sliced as ``[:, :1]`` and concatenated
            on axis 1 (assumes a 2-D [batch, length] layout with id 0 as
            padding -- TODO confirm against the preprocessor).
        params: configuration mapping. Keys read here: 'src_vocab_size',
            'hidden_size', 'shared_source_target_embedding',
            'position_info_type' and 'residual_dropout'; it is also handed
            on to ``transformer_encoder``.

    Returns:
        Tuple ``(encoder_output, encoder_self_attention_bias)``.
    """
    vocab_size = params['src_vocab_size']
    model_dim = params['hidden_size']
    embedding_init = tf.compat.v1.random_normal_initializer(
        0.0, model_dim**-0.5, dtype=tf.float32)

    # With a shared vocabulary the table lives in 'Shared_Embedding' (and
    # may already exist, hence AUTO_REUSE); otherwise in 'Source_Embedding'.
    if params['shared_source_target_embedding']:
        with tf.compat.v1.variable_scope(
                'Shared_Embedding', reuse=tf.compat.v1.AUTO_REUSE):
            embedding_table = tf.compat.v1.get_variable(
                'Weights', [vocab_size, model_dim],
                initializer=embedding_init)
    else:
        with tf.compat.v1.variable_scope('Source_Embedding'):
            embedding_table = tf.compat.v1.get_variable(
                'Weights', [vocab_size, model_dim],
                initializer=embedding_init)
    input_bias = tf.compat.v1.get_variable('encoder_input_bias',
                                           [model_dim])

    # Append one zero id to every row, then flag the non-zero positions.
    zero_column = tf.zeros_like(features, dtype=tf.int64)[:, :1]
    padded_ids = tf.concat([features, zero_column], 1)
    nonzero_mask = tf.cast(tf.not_equal(padded_ids, 0), dtype=tf.float32)
    # Shift the mask one step to the right, filling the first slot with 1,
    # so the first zero id after the last non-zero id stays unmasked.
    shifted_mask = nonzero_mask[:, :-1]
    shifted_mask = tf.pad(
        tensor=shifted_mask, paddings=[[0, 0], [1, 0]], constant_values=1)

    embedded = tf.gather(embedding_table, tf.cast(padded_ids, tf.int32))
    embedded = embedded * (model_dim**0.5)
    if params['position_info_type'] == 'absolute':
        embedded = add_timing_signal(embedded)
    # Zero out masked positions before adding the learned input bias.
    embedded = tf.multiply(embedded, tf.expand_dims(shifted_mask, 2))
    embedded = tf.nn.bias_add(embedded, input_bias)

    encoder_self_attention_bias = attention_bias(shifted_mask, 'masking')

    dropout_rate = params['residual_dropout']
    if dropout_rate > 0.0:
        embedded = tf.nn.dropout(embedded, rate=dropout_rate)

    encoder_output = transformer_encoder(
        embedded, encoder_self_attention_bias, shifted_mask, params)
    return encoder_output, encoder_self_attention_bias
def semantic_encoding_graph(self, features, params, name=None):
    """Embed ids and run them through ``transformer_semantic_encoder``.

    Args:
        features: integer id tensor (assumes [batch, length] with id 0 as
            padding -- TODO confirm).
        params: configuration mapping. Keys read here: 'hidden_size',
            'shared_source_target_embedding', 'src_vocab_size' or
            'trg_vocab_size', and 'residual_dropout'.
        name: 'source' or 'target'; selects the embedding table when the
            vocabularies are not shared. Ignored when
            'shared_source_target_embedding' is truthy.

    Returns:
        The output of ``transformer_semantic_encoder``.

    Raises:
        ValueError: if embeddings are not shared and ``name`` is neither
            'source' nor 'target'.
    """
    model_dim = params['hidden_size']
    embedding_init = tf.compat.v1.random_normal_initializer(
        0.0, model_dim**-0.5, dtype=tf.float32)

    # Pick the vocabulary size and the variable scope of the table.
    if params['shared_source_target_embedding']:
        vocab_size, scope = (params['src_vocab_size'],
                             'Shared_Semantic_Embedding')
    elif name == 'source':
        vocab_size, scope = (params['src_vocab_size'],
                             'Source_Semantic_Embedding')
    elif name == 'target':
        vocab_size, scope = (params['trg_vocab_size'],
                             'Target_Semantic_Embedding')
    else:
        raise ValueError('error: no right name specified.')

    # NOTE(review): only the table lookup is placed inside the scope; the
    # original indentation was lost, so confirm the encoder is meant to be
    # built outside of it.
    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        embedding_table = tf.compat.v1.get_variable(
            'Weights', [vocab_size, model_dim], initializer=embedding_init)

    # Append one zero id per row and build the right-shifted non-zero mask.
    zero_column = tf.zeros_like(features, dtype=tf.int64)[:, :1]
    padded_ids = tf.concat([features, zero_column], 1)
    nonzero_mask = tf.cast(tf.not_equal(padded_ids, 0), dtype=tf.float32)
    shifted_mask = nonzero_mask[:, :-1]
    shifted_mask = tf.pad(
        tensor=shifted_mask, paddings=[[0, 0], [1, 0]], constant_values=1)

    embedded = tf.gather(embedding_table, tf.cast(padded_ids, tf.int32))
    embedded = embedded * (model_dim**0.5)
    embedded = tf.multiply(embedded, tf.expand_dims(shifted_mask, 2))

    self_attention_bias = attention_bias(shifted_mask, 'masking')

    dropout_rate = params['residual_dropout']
    if dropout_rate > 0.0:
        embedded = tf.nn.dropout(embedded, rate=dropout_rate)

    return transformer_semantic_encoder(embedded, self_attention_bias,
                                        shifted_mask, params)
def build_contrastive_training_graph(self, features, labels, params):
    """Compute semantic-encoder outputs for the source and the target side.

    Args:
        features: source-side ids, forwarded to ``semantic_encoding_graph``.
        labels: target-side ids, forwarded to ``semantic_encoding_graph``.
        params: configuration mapping; 'shared_source_target_embedding' is
            read here, the rest is used by ``semantic_encoding_graph``.

    Returns:
        Tuple ``(feature_output, label_output)``.
    """
    # With shared embeddings the side name is irrelevant, so None is passed.
    shared = params['shared_source_target_embedding']
    source_name = None if shared else 'source'
    target_name = None if shared else 'target'

    feature_output = self.semantic_encoding_graph(
        features, params, name=source_name)
    label_output = self.semantic_encoding_graph(
        labels, params, name=target_name)
    return feature_output, label_output
def MGMC_sampling(self, x_embedding, y_embedding, params, epsilon=1e-12):
    """Draw ``params['num_of_samples']`` random points around two embeddings.

    Half of the samples start at ``x_embedding`` and move along
    ``y_embedding - x_embedding``; the other half start at ``y_embedding``
    and move along the opposite direction. Each step length is a mixture
    (weight ``params['eta']``) of two zero-mean normal draws: one whose
    stddev is the min-max normalised magnitude of the direction vector
    along axis 2, and one with unit stddev.

    Args:
        x_embedding: first embedding tensor (reduced over axis 2, so at
            least rank 3 -- presumably [batch, length, hidden]; verify).
        y_embedding: second embedding tensor, same shape as the first.
        params: configuration mapping; reads 'num_of_samples' (must be
            even) and 'eta'.
        epsilon: small constant keeping the normalisation away from 0/0.

    Returns:
        All samples concatenated along axis 0.
    """
    num_samples = params['num_of_samples']
    eta = params['eta']
    assert num_samples % 2 == 0

    def get_samples(origin, destination):
        direction = destination - origin
        # Min-max normalise |direction| over axis 2 to obtain a stddev.
        numerator = (
            tf.abs(direction) - tf.reduce_min(
                input_tensor=tf.abs(direction), axis=2, keepdims=True)
            + epsilon)
        denominator = (
            tf.reduce_max(
                input_tensor=tf.abs(direction), axis=2, keepdims=True)
            - tf.reduce_min(
                input_tensor=tf.abs(direction), axis=2, keepdims=True)
            + 2 * epsilon)
        w_r = tf.math.divide(numerator, denominator)

        drawn = []
        for _ in range(num_samples // 2):
            scaled_noise = eta * tf.random.normal(
                tf.shape(input=direction), 0.0, w_r)
            unit_noise = (1.0 - eta) * tf.random.normal(
                tf.shape(input=direction), 0.0, 1.0)
            omega = scaled_noise + unit_noise
            drawn.append(origin + omega * direction)
        return drawn

    samples = get_samples(x_embedding, y_embedding)
    samples.extend(get_samples(y_embedding, x_embedding))
    assert len(samples) == num_samples
    return tf.concat(samples, axis=0)
def decoding_graph(self,
                   encoder_output,
                   encoder_self_attention_bias,
                   labels,
                   params={},
                   embedding_augmentation=None):
    """Build the training-time decoder and return the smoothed loss.

    Args:
        encoder_output: encoder states, forwarded to ``transformer_decoder``.
        encoder_self_attention_bias: attention bias returned together with
            the encoder states.
        labels: target id tensor (assumes [batch, length] with id 0 as
            padding -- TODO confirm).
        params: configuration mapping. Keys read here: 'trg_vocab_size',
            'hidden_size', 'shared_source_target_embedding',
            'position_info_type', 'residual_dropout' and 'confidence'.
        embedding_augmentation: optional tensor forwarded to
            ``transformer_decoder``.

    Returns:
        Scalar loss: label-smoothed softmax cross-entropy summed over the
        unmasked positions and divided by their count.
    """
    vocab_size = params['trg_vocab_size']
    model_dim = params['hidden_size']
    embedding_init = tf.compat.v1.random_normal_initializer(
        0.0, model_dim**-0.5, dtype=tf.float32)

    if params['shared_source_target_embedding']:
        with tf.compat.v1.variable_scope(
                'Shared_Embedding', reuse=tf.compat.v1.AUTO_REUSE):
            embedding_table = tf.compat.v1.get_variable(
                'Weights', [vocab_size, model_dim],
                initializer=embedding_init)
    else:
        with tf.compat.v1.variable_scope('Target_Embedding'):
            embedding_table = tf.compat.v1.get_variable(
                'Weights', [vocab_size, model_dim],
                initializer=embedding_init)

    # Append one zero id per row and build the right-shifted non-zero mask
    # that later selects the positions contributing to the loss.
    zero_column = tf.zeros_like(labels, dtype=tf.int64)[:, :1]
    target_ids = tf.concat([labels, zero_column], 1)
    nonzero_mask = tf.cast(tf.not_equal(target_ids, 0), dtype=tf.float32)
    loss_mask = nonzero_mask[:, :-1]
    loss_mask = tf.pad(
        tensor=loss_mask, paddings=[[0, 0], [1, 0]], constant_values=1)

    embedded = tf.gather(embedding_table, tf.cast(target_ids, tf.int32))
    embedded *= model_dim**0.5
    causal_bias = attention_bias(tf.shape(input=embedded)[1], 'causal')
    # Shift the decoder input right by one step (a zero vector comes first).
    embedded = tf.pad(
        tensor=embedded, paddings=[[0, 0], [1, 0], [0, 0]])[:, :-1, :]
    if params['position_info_type'] == 'absolute':
        embedded = add_timing_signal(embedded)
    embedded = tf.nn.dropout(
        embedded, rate=1 - (1.0 - params['residual_dropout']))

    decoder_output, _ = transformer_decoder(
        embedded,
        encoder_output,
        causal_bias,
        encoder_self_attention_bias,
        states_key=None,
        states_val=None,
        embedding_augmentation=embedding_augmentation,
        params=params)
    logits = self.prediction(decoder_output, params)

    # Label smoothing: the gold id gets `confidence`, the remaining mass is
    # spread evenly over the other vocab_size - 1 ids.
    gold_prob = params['confidence']
    other_prob = (1.0 - params['confidence']) / tf.cast(
        vocab_size - 1, dtype=tf.float32)
    soft_targets = tf.one_hot(
        tf.cast(target_ids, tf.int32),
        depth=vocab_size,
        on_value=gold_prob,
        off_value=other_prob)

    weights = tf.cast(loss_mask, logits.dtype)
    token_losses = tf.nn.softmax_cross_entropy_with_logits(
        logits=logits, labels=tf.stop_gradient(soft_targets)) * weights
    return tf.reduce_sum(input_tensor=token_losses) / tf.reduce_sum(
        input_tensor=weights)
def build_training_graph(self,
                         features,
                         labels,
                         params,
                         feature_embedding=None,
                         label_embedding=None):
    """Encode the source, optionally augment, decode and return the loss.

    Args:
        features: source ids, forwarded to ``encoding_graph``.
        labels: target ids, forwarded to ``decoding_graph``.
        params: configuration mapping; 'num_of_samples' is read here when
            augmentation is active.
        feature_embedding: optional source-side semantic embedding.
        label_embedding: optional target-side semantic embedding.

    Returns:
        The loss produced by ``decoding_graph``.
    """
    encoder_output, encoder_self_attention_bias = self.encoding_graph(
        features, params)

    augmentation = None
    use_augmentation = (feature_embedding is not None
                        and label_embedding is not None)
    if use_augmentation:
        augmentation = self.MGMC_sampling(feature_embedding,
                                          label_embedding, params)
        # MGMC_sampling stacks its samples along axis 0, so the other
        # decoder inputs are repeated the same number of times on axis 0.
        repeats = params['num_of_samples']
        encoder_output = tf.tile(encoder_output, [repeats, 1, 1])
        encoder_self_attention_bias = tf.tile(
            encoder_self_attention_bias, [repeats, 1, 1, 1])
        labels = tf.tile(labels, [repeats, 1])

    return self.decoding_graph(
        encoder_output,
        encoder_self_attention_bias,
        labels,
        params,
        embedding_augmentation=augmentation)
def transformer_model_train_fn(self, features, labels):
    """Build the (multi-device) training graph.

    Reads from ``self.params``: 'num_gpus', 'gradient_clip_norm',
    'learning_rate', 'optimizer' ('sgd' or 'adam'), the 'adam_*' settings
    and 'update_cycle'.

    Args:
        features: source id tensor; split along axis 0 across devices.
        labels: target id tensor; split along axis 0 across devices.

    Returns:
        Tuple ``(train_op, total_loss)`` where ``total_loss`` is the mean of
        the per-device losses.

    Raises:
        ValueError: if ``self.params['optimizer']`` is not supported. (The
            previous code called ``sys.exit()`` without ``sys`` being
            imported, which surfaced as an unrelated ``NameError``.)
    """
    initializer = get_initializer(self.params)
    with tf.compat.v1.variable_scope('NmtModel', initializer=initializer):
        num_gpus = self.params['num_gpus']
        gradient_clip_norm = self.params['gradient_clip_norm']
        global_step = tf.compat.v1.train.get_global_step()
        print(global_step)

        # learning rate
        learning_rate = get_learning_rate_decay(
            self.params['learning_rate'], global_step, self.params)
        learning_rate = tf.convert_to_tensor(
            value=learning_rate, dtype=tf.float32)

        # optimizer
        optimizer_name = self.params['optimizer']
        if optimizer_name == 'sgd':
            optimizer = tf.compat.v1.train.GradientDescentOptimizer(
                learning_rate)
        elif optimizer_name == 'adam':
            optimizer = tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate,
                beta1=self.params['adam_beta1'],
                beta2=self.params['adam_beta2'],
                epsilon=self.params['adam_epsilon'])
        else:
            tf.compat.v1.logging.info('optimizer not supported')
            raise ValueError(
                "optimizer %r not supported; expected 'sgd' or 'adam'" %
                (optimizer_name, ))
        opt = MultiStepOptimizer(optimizer, self.params['update_cycle'])

        def fill_gpus(inputs, num_gpus):
            # Repeat the batch until it has at least `num_gpus` rows, then
            # keep exactly `num_gpus` of them so every device gets a row.
            outputs = inputs
            for _ in range(num_gpus):
                outputs = tf.concat([outputs, inputs], axis=0)
                outputs = outputs[:num_gpus, ]
            return outputs

        features = tf.cond(
            pred=tf.shape(input=features)[0] < num_gpus,
            true_fn=lambda: fill_gpus(features, num_gpus),
            false_fn=lambda: features)
        labels = tf.cond(
            pred=tf.shape(input=labels)[0] < num_gpus,
            true_fn=lambda: fill_gpus(labels, num_gpus),
            false_fn=lambda: labels)

        if num_gpus > 0:
            feature_shards = shard_features(features, num_gpus)
            label_shards = shard_features(labels, num_gpus)
            devices = ['gpu:%d' % d for d in range(num_gpus)]
        else:
            feature_shards = [features]
            label_shards = [labels]
            devices = ['cpu:0']

        multi_grads = []
        sharded_losses = []
        for i, device in enumerate(devices):
            # Variables are created on the first device and reused after.
            with tf.device(device), tf.compat.v1.variable_scope(
                    tf.compat.v1.get_variable_scope(),
                    reuse=True if i > 0 else None):
                with tf.name_scope('%s_%d' % ('GPU', i)):
                    feature_output, label_output = \
                        self.build_contrastive_training_graph(
                            feature_shards[i], label_shards[i],
                            self.params)
                    mle_loss = self.build_training_graph(
                        feature_shards[i], label_shards[i], self.params,
                        feature_output, label_output)
                    sharded_losses.append(mle_loss)
                    tf.compat.v1.summary.scalar('mle_loss_{}'.format(i),
                                                mle_loss)

                    # Optimization: variables whose name contains
                    # 'Semantic_Embedding' or 'mini_xlm_encoder' are
                    # excluded from the update.
                    trainable_vars_list = [
                        v for v in tf.compat.v1.trainable_variables()
                        if 'Semantic_Embedding' not in v.name
                        and 'mini_xlm_encoder' not in v.name
                    ]
                    grads_and_vars = opt.compute_gradients(
                        mle_loss,
                        var_list=trainable_vars_list,
                        colocate_gradients_with_ops=True)
                    multi_grads.append(grads_and_vars)

        total_loss = tf.add_n(sharded_losses) / len(sharded_losses)

        # Average gradients
        grads_and_vars = average_gradients(multi_grads)
        if gradient_clip_norm > 0.0:
            grads, var_list = list(zip(*grads_and_vars))
            grads, _ = tf.clip_by_global_norm(grads, gradient_clip_norm)
            grads_and_vars = zip(grads, var_list)

        train_op = opt.apply_gradients(
            grads_and_vars, global_step=global_step)
        return train_op, total_loss
def prediction(self, decoder_output, params):
    """Project decoder states onto the target vocabulary.

    Args:
        decoder_output: tensor whose last axis has size
            ``params['hidden_size']``; all leading axes are preserved.
        params: configuration mapping. Keys read here: 'hidden_size',
            'trg_vocab_size', 'shared_embedding_and_softmax_weights' and
            'shared_source_target_embedding'.

    Returns:
        Logits with the leading axes of ``decoder_output`` and a last axis
        of size ``params['trg_vocab_size']``.
    """
    hidden_size = params['hidden_size']
    trg_vocab_size = params['trg_vocab_size']

    if params['shared_embedding_and_softmax_weights']:
        # Tie the output projection to the (already created) embedding.
        embedding_scope = 'Shared_Embedding' if params[
            'shared_source_target_embedding'] else 'Target_Embedding'
        with tf.compat.v1.variable_scope(embedding_scope, reuse=True):
            weights = tf.compat.v1.get_variable('Weights')
    else:
        # Bug fix: this used the undefined name `tgt_vocab_size`, so the
        # untied-softmax configuration crashed with NameError.
        weights = tf.compat.v1.get_variable('Softmax',
                                            [trg_vocab_size, hidden_size])

    # Flatten to 2-D for the matmul, then restore the leading axes.
    leading_shape = tf.shape(input=decoder_output)[:-1]
    flat_output = tf.reshape(decoder_output, [-1, hidden_size])
    logits = tf.matmul(flat_output, weights, transpose_b=True)
    return tf.reshape(logits,
                      tf.concat([leading_shape, [trg_vocab_size]], 0))
def inference_func(self,
                   encoder_output,
                   feature_output,
                   encoder_self_attention_bias,
                   trg_seq,
                   states_key,
                   states_val,
                   params={},
                   is_prefix=False):
    """Run one decoder call for beam search and return log-probabilities.

    Args:
        encoder_output: encoder states.
        feature_output: semantic-encoder output, passed to
            ``transformer_decoder`` as ``embedding_augmentation``.
        encoder_self_attention_bias: attention bias over source positions.
        trg_seq: target ids decoded so far (assumes [batch * beam, length]
            -- TODO confirm).
        states_key: per-layer key states handed to ``transformer_decoder``.
        states_val: per-layer value states handed to ``transformer_decoder``.
        params: configuration mapping. Keys read here: 'trg_vocab_size',
            'hidden_size', 'shared_source_target_embedding',
            'position_info_type', 'shared_embedding_and_softmax_weights'.
        is_prefix: when False only the last position is fed to the decoder
            and scored; when True every position is.

    Returns:
        Tuple ``(log_prob, attention_weights, states_key, states_val)``.
        The two state objects are the ones passed in; any update to them
        has to happen inside ``transformer_decoder`` (not visible here).
    """
    vocab_size = params['trg_vocab_size']
    model_dim = params['hidden_size']
    embedding_init = tf.compat.v1.random_normal_initializer(
        0.0, model_dim**-0.5, dtype=tf.float32)

    if params['shared_source_target_embedding']:
        with tf.compat.v1.variable_scope(
                'Shared_Embedding', reuse=tf.compat.v1.AUTO_REUSE):
            embedding_table = tf.compat.v1.get_variable(
                'Weights', [vocab_size, model_dim],
                initializer=embedding_init)
    else:
        with tf.compat.v1.variable_scope('Target_Embedding'):
            embedding_table = tf.compat.v1.get_variable(
                'Weights', [vocab_size, model_dim],
                initializer=embedding_init)

    embedded = tf.gather(embedding_table, tf.cast(trg_seq, tf.int32))
    embedded *= model_dim**0.5
    causal_bias = attention_bias(tf.shape(input=embedded)[1], 'causal')
    # Shift right by one step so position t only sees ids before t.
    embedded = tf.pad(
        tensor=embedded, paddings=[[0, 0], [1, 0], [0, 0]])[:, :-1, :]
    if params['position_info_type'] == 'absolute':
        embedded = add_timing_signal(embedded)

    if not is_prefix:
        # Incremental step: only the newest position is decoded.
        embedded = embedded[:, -1:, :]
        causal_bias = causal_bias[:, :, -1:, :]

    decoder_output, attention_weights = transformer_decoder(
        embedded,
        encoder_output,
        causal_bias,
        encoder_self_attention_bias,
        states_key=states_key,
        states_val=states_val,
        embedding_augmentation=feature_output,
        params=params)

    if is_prefix:
        scored_output = decoder_output
        scored_attention = attention_weights
    else:
        scored_output = decoder_output[:, -1, :]
        scored_attention = attention_weights[:, -1, :]

    if params['shared_embedding_and_softmax_weights']:
        if params['shared_source_target_embedding']:
            embedding_scope = 'Shared_Embedding'
        else:
            embedding_scope = 'Target_Embedding'
        with tf.compat.v1.variable_scope(embedding_scope, reuse=True):
            weights = tf.compat.v1.get_variable('Weights')
    else:
        weights = tf.compat.v1.get_variable('Softmax',
                                            [vocab_size, model_dim])

    logits = tf.matmul(scored_output, weights, transpose_b=True)
    log_prob = tf.nn.log_softmax(logits)
    return log_prob, scored_attention, states_key, states_val
def beam_search(self, features, params):
    """Build the beam-search decoding graph.

    Args:
        features: mapping with 'input_wids' (source ids) and, optionally,
            'prefix_wids' / 'prefix_hit' for prefix-constrained decoding
            (both may be present with value None).
        params: configuration mapping. Keys read here: 'beam_size',
            'trg_vocab_size', 'hidden_size', 'num_decoder_layers',
            'lp_rate', 'max_decoded_trg_len' and
            'shared_source_target_embedding'.

    Returns:
        Tuple ``(final_seqs, final_scores)``: per batch row the finished
        hypotheses if any hypothesis finished, otherwise the hypotheses
        still alive when the loop stopped.
    """
    beam_size = params['beam_size']
    vocab_size = params['trg_vocab_size']
    model_dim = params['hidden_size']
    layer_ids = range(params['num_decoder_layers'])
    lp_rate = params['lp_rate']
    max_steps = params['max_decoded_trg_len']

    source_ids = features['input_wids']
    has_prefix_entry = 'prefix_wids' in features
    prefix = features['prefix_wids'] if has_prefix_entry else None
    prefix_hit = features['prefix_hit'] if has_prefix_entry else None
    batch_size = tf.shape(source_ids)[0]

    # --- small per-layer helpers -------------------------------------
    def _merge_states(states):
        # [batch, beam, ...] -> [batch * beam, ...] for every layer.
        return [merge_first_two_dims(states[layer]) for layer in layer_ids]

    def _split_states(states):
        # [batch * beam, ...] -> [batch, beam, ...] for every layer.
        return [
            split_first_two_dims(states[layer], batch_size, beam_size)
            for layer in layer_ids
        ]

    def _tile_states(states):
        return [
            tile_to_beam_size(states[layer], beam_size)
            for layer in layer_ids
        ]

    def _empty_states():
        # Zero-length caches: nothing has been decoded yet.
        return [
            tf.fill([batch_size, 0, model_dim], 0.0) for _ in layer_ids
        ]

    # --- encode the source once per beam entry -----------------------
    source_ids = tile_to_beam_size(source_ids, beam_size)
    source_ids = merge_first_two_dims(source_ids)
    if prefix is not None:
        prefix = tf.cast(tile_to_beam_size(prefix, beam_size), tf.int32)
        prefix_hit = tile_to_beam_size(prefix_hit, beam_size)

    encoder_output, encoder_self_attention_bias = self.encoding_graph(
        source_ids, params)
    source_name = (None
                   if params['shared_source_target_embedding'] else 'source')
    feature_output = self.semantic_encoding_graph(
        source_ids, params, name=source_name)

    states_key = _empty_states()
    states_val = _empty_states()
    for layer in layer_ids:
        states_key[layer].set_shape(
            tf.TensorShape([None, None, model_dim]))
        states_val[layer].set_shape(
            tf.TensorShape([None, None, model_dim]))
    states_key = _tile_states(states_key)
    states_val = _tile_states(states_val)

    # --- initial hypotheses -------------------------------------------
    fixed_length = 1
    if prefix is None:
        init_seqs = tf.fill([batch_size, beam_size, 1], 0)
    else:
        # Feed the whole prefix through the decoder once to fill the
        # caches and pick the first token allowed by `prefix_hit`.
        init_seqs = tf.concat(
            [prefix, tf.fill([batch_size, beam_size, 1], 0)], axis=2)
        fixed_length = tf.shape(init_seqs)[-1]
        flat_prefix = merge_first_two_dims(init_seqs)
        flat_key = _merge_states(states_key)
        flat_val = _merge_states(states_val)
        prefix_log_probs, _, prefix_key, prefix_val = self.inference_func(
            encoder_output,
            feature_output,
            encoder_self_attention_bias,
            flat_prefix,
            flat_key,
            flat_val,
            params=params,
            is_prefix=True)
        states_key = _split_states(prefix_key)
        states_val = _split_states(prefix_val)

        prefix_hit = merge_first_two_dims(prefix_hit)
        # Tokens not flagged in prefix_hit get the lowest possible score.
        constrained = tf.where(
            prefix_hit, prefix_log_probs[:, -1, :],
            tf.ones_like(prefix_log_probs[:, -1, :]) * tf.float32.min)
        init_seqs = tf.concat([
            flat_prefix[:, :-1],
            tf.expand_dims(
                tf.cast(tf.argmax(constrained, -1), tf.int32), -1)
        ], -1)
        init_seqs = split_first_two_dims(init_seqs, batch_size, beam_size)
        init_seqs = tf.concat(
            [init_seqs, tf.fill([batch_size, beam_size, 1], 0)], axis=2)

    # Only beam entry 0 starts with a usable log-probability, so the
    # identical initial entries do not fill the beam with duplicates.
    init_log_probs = \
        tf.constant([[0.] + [tf.float32.min] * (beam_size - 1)])
    init_log_probs = tf.tile(init_log_probs, [batch_size, 1])
    init_scores = tf.zeros_like(init_log_probs)
    init_fin_seqs = init_seqs
    init_fin_scores = tf.fill([batch_size, beam_size], tf.float32.min)
    init_fin_flags = tf.cast(
        tf.fill([batch_size, beam_size], 0), tf.bool)

    state = BeamSearchState(
        inputs=(init_seqs, init_log_probs, init_scores),
        state=(states_key, states_val),
        finish=(init_fin_flags, init_fin_seqs, init_fin_scores),
    )

    def _beam_search_step(time, state):
        seqs, log_probs = state.inputs[:2]
        cached_key, cached_val = state.state

        flat_seqs = merge_first_two_dims(seqs)
        flat_key = _merge_states(cached_key)
        flat_val = _merge_states(cached_val)
        step_log_probs, _, step_key, step_val = self.inference_func(
            encoder_output,
            feature_output,
            encoder_self_attention_bias,
            flat_seqs,
            flat_key,
            flat_val,
            params=params,
            is_prefix=False)
        step_log_probs = split_first_two_dims(step_log_probs, batch_size,
                                              beam_size)
        curr_log_probs = tf.expand_dims(log_probs, 2) + step_log_probs
        next_key = _split_states(step_key)
        next_val = _split_states(step_val)

        # Apply length penalty
        length_penalty = tf.pow(
            (5.0 + tf.cast(time + 1, dtype=tf.float32)) / 6.0, lp_rate)
        curr_scores = curr_log_probs / length_penalty

        # Select top-k candidates over [batch, beam * vocab]; 2 * beam are
        # kept so that enough unfinished ones remain afterwards.
        curr_scores = tf.reshape(curr_scores,
                                 [-1, beam_size * vocab_size])
        top_scores, top_indices = tf.nn.top_k(curr_scores, k=2 * beam_size)
        beam_indices = top_indices // vocab_size
        symbol_indices = top_indices % vocab_size

        # Expand sequences: overwrite the trailing slot with the new
        # symbol and append a fresh zero slot.
        candidate_seqs = gather_2d(seqs, beam_indices)
        candidate_seqs = tf.concat(
            [candidate_seqs[:, :, :-1],
             tf.expand_dims(symbol_indices, 2)],
            axis=2)
        wide_pad = tf.fill([batch_size, 2 * beam_size, 1],
                           tf.constant(0, tf.int32))
        candidate_seqs = tf.concat([candidate_seqs, wide_pad], axis=2)

        # Suppress finished sequences (symbol 0 ends a hypothesis).
        flags = tf.equal(symbol_indices, 0)
        alive_scores = top_scores + tf.cast(
            flags, dtype=tf.float32) * tf.float32.min
        alive_scores, alive_slots = tf.nn.top_k(alive_scores, beam_size)
        alive_symbols = gather_2d(symbol_indices, alive_slots)
        alive_beams = gather_2d(beam_indices, alive_slots)
        alive_seqs = gather_2d(seqs, alive_beams)
        alive_seqs = tf.concat(
            [alive_seqs[:, :, :-1],
             tf.expand_dims(alive_symbols, 2)],
            axis=2)
        beam_pad = tf.fill([batch_size, beam_size, 1],
                           tf.constant(0, tf.int32))
        alive_seqs = tf.concat([alive_seqs, beam_pad], axis=2)
        alive_key = [
            gather_2d(next_key[layer], alive_beams) for layer in layer_ids
        ]
        alive_val = [
            gather_2d(next_val[layer], alive_beams) for layer in layer_ids
        ]
        alive_log_probs = alive_scores * length_penalty

        # Select finished sequences among the previous finished ones and
        # the candidates that ended at this step.
        prev_fin_flags, prev_fin_seqs, prev_fin_scores = state.finish
        step_fin_scores = top_scores + (
            1.0 - tf.cast(flags, dtype=tf.float32)) * tf.float32.min
        fin_flags = tf.concat([prev_fin_flags, flags], axis=1)
        fin_scores = tf.concat([prev_fin_scores, step_fin_scores], axis=1)
        fin_scores, fin_slots = tf.nn.top_k(fin_scores, beam_size)
        fin_flags = gather_2d(fin_flags, fin_slots)
        fin_pad = tf.fill([batch_size, beam_size, 1],
                          tf.constant(0, tf.int32))
        prev_fin_seqs = tf.concat([prev_fin_seqs, fin_pad], axis=2)
        fin_seqs = tf.concat([prev_fin_seqs, candidate_seqs], axis=1)
        fin_seqs = gather_2d(fin_seqs, fin_slots)

        new_state = BeamSearchState(
            inputs=(alive_seqs, alive_log_probs, alive_scores),
            state=(alive_key, alive_val),
            finish=(fin_flags, fin_seqs, fin_scores),
        )
        return time + 1, new_state

    def _is_finished(t, s):
        # Despite its name this is the loop condition: it is True while
        # decoding should continue.
        log_probs = s.inputs[1]
        finished_flags = s.finish[0]
        finished_scores = s.finish[2]

        max_lp = tf.pow(
            ((5.0 + tf.cast(max_steps, dtype=tf.float32)) / 6.0), lp_rate)
        best_alive_score = log_probs[:, 0] / max_lp
        worst_finished_score = tf.reduce_min(
            input_tensor=finished_scores
            * tf.cast(finished_flags, dtype=tf.float32),
            axis=1)
        # Rows without any finished hypothesis can never meet the bound.
        none_finished = 1.0 - tf.cast(
            tf.reduce_any(input_tensor=finished_flags, axis=1),
            dtype=tf.float32)
        worst_finished_score += tf.float32.min * none_finished
        bound_is_met = tf.reduce_all(
            input_tensor=tf.greater(worst_finished_score,
                                    best_alive_score))
        return tf.logical_and(
            tf.less(t, max_steps), tf.logical_not(bound_is_met))

    time = tf.constant(0, name='time')
    # Sequence length and cache length grow every step, so those axes are
    # declared unknown.
    shape_invariants = BeamSearchState(
        inputs=(tf.TensorShape([None, None, None]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, None])),
        state=([
            tf.TensorShape([None, None, None, model_dim])
            for _ in layer_ids
        ], [
            tf.TensorShape([None, None, None, model_dim])
            for _ in layer_ids
        ]),
        finish=(tf.TensorShape([None, None]),
                tf.TensorShape([None, None, None]),
                tf.TensorShape([None, None])))
    outputs = tf.while_loop(
        cond=_is_finished,
        body=_beam_search_step,
        loop_vars=[time, state],
        shape_invariants=[tf.TensorShape([]), shape_invariants],
        parallel_iterations=1,
        back_prop=False)

    final_state = outputs[1]
    alive_seqs = final_state.inputs[0]
    alive_scores = final_state.inputs[2]
    final_flags = final_state.finish[0]
    final_seqs = final_state.finish[1]
    final_scores = final_state.finish[2]

    alive_seqs.set_shape([None, beam_size, None])
    final_seqs.set_shape([None, beam_size, None])

    # Fall back to the alive hypotheses for rows where nothing finished.
    final_seqs = tf.compat.v1.where(
        tf.reduce_any(input_tensor=final_flags, axis=1), final_seqs,
        alive_seqs)
    final_scores = tf.compat.v1.where(
        tf.reduce_any(input_tensor=final_flags, axis=1), final_scores,
        alive_scores)
    # Drop the leading fixed part and the trailing zero slot.
    final_seqs = final_seqs[:, :, fixed_length - 1:-1]
    return final_seqs, final_scores
class BeamSearchState(
        namedtuple('BeamSearchState', ['inputs', 'state', 'finish'])):
    """Loop state of beam search: ``(inputs, state, finish)``.

    ``inputs`` holds the live hypotheses, ``state`` the decoder caches and
    ``finish`` the finished hypotheses, as filled in by ``beam_search``.
    """
def tile_to_beam_size(tensor, beam_size):
    """Insert a new axis 1 and repeat the tensor ``beam_size`` times on it."""
    expanded = tf.expand_dims(tensor, axis=1)
    # Repeat only along the freshly inserted axis.
    multiples = [
        beam_size if axis == 1 else 1
        for axis in range(expanded.shape.ndims)
    ]
    return tf.tile(expanded, multiples)
def infer_shape(x):
    """Return the shape of ``x``, static where known and dynamic elsewhere.

    Returns:
        ``tf.shape(x)`` when the rank is unknown; otherwise a Python list
        with one entry per axis, an ``int`` for statically known sizes and
        a scalar slice of ``tf.shape(x)`` for the rest.
    """
    x = tf.convert_to_tensor(x)
    if x.shape.dims is None:
        return tf.shape(x)

    static_dims = x.shape.as_list()
    dynamic_dims = tf.shape(x)
    return [
        dynamic_dims[axis] if size is None else size
        for axis, size in enumerate(static_dims)
    ]
def split_first_two_dims(tensor, dim_0, dim_1):
    """Reshape ``tensor`` so its leading axis becomes ``[dim_0, dim_1]``.

    :param tensor: Tensor whose first axis is to be split
    :param dim_0: Size of the new first axis
    :param dim_1: Size of the new second axis
    :return: ``tensor`` reshaped to ``[dim_0, dim_1] + trailing axes``
    """
    trailing = infer_shape(tensor)[1:]
    return tf.reshape(tensor, [dim_0, dim_1] + trailing)
def merge_first_two_dims(tensor):
    """Collapse the first two axes of ``tensor`` into a single axis.

    :param tensor: Tensor with at least two axes
    :return: ``tensor`` reshaped so that axis 0 has size
        ``shape[0] * shape[1]`` and the remaining axes are unchanged
    """
    dims = infer_shape(tensor)
    merged = dims[0] * dims[1]
    # Replace the two leading entries by their product, in place.
    dims[0:2] = [merged]
    return tf.reshape(tensor, dims)
def gather_2d(params, indices, name=None):
    """Gather along the second axis of ``params``, independently per row.

    :param params: A tensor with shape [batch_size, M, ...]
    :param indices: A tensor with shape [batch_size, N]
    :param name: An optional name for the gather op
    :return: A tensor with shape [batch_size, N, ...]
    """
    batch_size = tf.shape(params)[0]
    range_size = tf.shape(indices)[1]
    # Row id of every (row, column) slot: 0,0,...,0,1,1,...,1,...
    flat_rows = tf.range(batch_size * range_size) // range_size
    row_ids = tf.reshape(flat_rows, [batch_size, range_size])
    # Pair each requested column with its row to form gather_nd coordinates.
    coords = tf.stack([row_ids, indices], axis=-1)
    return tf.gather_nd(params, coords, name=name)
def linear(inputs, output_size, bias, concat=True, dtype=None, scope=None):
    """Project the last axis of one or several tensors to ``output_size``.

    :param inputs: A tensor, or a list/tuple of tensors
    :param output_size: Size of the last axis of the result
    :param bias: If truthy, a learned ``bias`` variable is added
    :param concat: If True, the flattened inputs are concatenated and
        multiplied by a single ``matrix`` variable; otherwise each input
        gets its own ``matrix_<i>`` variable and the products are summed
    :param dtype: Passed to the variable scope
    :param scope: Variable scope name (defaults to ``linear``)
    :return: Tensor shaped like the first input except for the last axis,
        which has size ``output_size``
    """
    with tf.compat.v1.variable_scope(
            scope, default_name='linear', values=[inputs], dtype=dtype):
        tensors = inputs if isinstance(inputs, (list, tuple)) else [inputs]
        last_dims = [t.get_shape()[-1] for t in tensors]

        if len(tensors) != len(last_dims):
            raise RuntimeError('inputs and input_size unmatched!')

        # Leading axes of the first input followed by the new last axis.
        target_shape = tf.concat(
            [tf.shape(tensors[0])[:-1], [output_size]], axis=0)

        # Flatten to 2D
        flat = [tf.reshape(t, [-1, t.shape[-1]]) for t in tensors]

        if concat:
            merged = tf.concat(flat, 1)
            weight = tf.compat.v1.get_variable(
                'matrix', [sum(last_dims), output_size])
            products = [tf.matmul(merged, weight)]
        else:
            products = []
            for idx, (piece, width) in enumerate(zip(flat, last_dims)):
                weight = tf.compat.v1.get_variable(
                    'matrix_%d' % idx, [width, output_size])
                products.append(tf.matmul(piece, weight))

        projected = tf.add_n(products)

        if bias:
            bias_var = tf.compat.v1.get_variable('bias', [output_size])
            projected = tf.nn.bias_add(projected, bias_var)

        return tf.reshape(projected, target_shape)
def layer_norm(inputs, epsilon=1e-6, name=None, reuse=None):
    """Normalize ``inputs`` over the last axis with a learned scale/offset.

    :param inputs: Tensor whose last axis has a statically known size
    :param epsilon: Constant added to the variance before the reciprocal
        square root
    :param name: Variable scope name (defaults to ``layer_norm``)
    :param reuse: Forwarded to the variable scope
    :return: ``(inputs - mean) * rsqrt(variance + epsilon) * scale + offset``
    """
    with tf.compat.v1.variable_scope(
            name, default_name='layer_norm', values=[inputs], reuse=reuse):
        param_shape = [inputs.get_shape().as_list()[-1]]

        gain = tf.compat.v1.get_variable(
            'layer_norm_scale',
            param_shape,
            initializer=tf.ones_initializer())
        shift = tf.compat.v1.get_variable(
            'layer_norm_offset',
            param_shape,
            initializer=tf.zeros_initializer())

        # Statistics are taken over the last axis and kept broadcastable.
        mu = tf.reduce_mean(inputs, -1, True)
        sigma_sq = tf.reduce_mean(tf.square(inputs - mu), -1, True)
        normalized = (inputs - mu) * tf.compat.v1.rsqrt(sigma_sq + epsilon)

        return normalized * gain + shift
- def _layer_process(x, mode):
- if not mode or mode == 'none':
- return x
- elif mode == 'layer_norm':
- return layer_norm(x)
- else:
- raise ValueError('Unknown mode %s' % mode)
- def _residual_fn(x, y, keep_prob=None):
- if keep_prob and keep_prob < 1.0:
- y = tf.nn.dropout(y, rate=1 - (keep_prob))
- return x + y
def embedding_augmentation_layer(x, embedding_augmentation, params, name=None):
    """Add a two-layer projection of ``embedding_augmentation`` to ``x``.

    :param x: Tensor the projected augmentation is added to
    :param embedding_augmentation: Tensor fed through linear -> ReLU ->
        (dropout) -> linear
    :param params: Mapping providing ``hidden_size`` and ``relu_dropout``
    :param name: Variable scope name (defaults to
        ``embedding_augmentation_layer``)
    :return: ``x`` plus the projected augmentation
    """
    hidden_size = params['hidden_size']
    keep_prob = 1.0 - params['relu_dropout']

    with tf.compat.v1.variable_scope(
            name,
            default_name='embedding_augmentation_layer',
            values=[x, embedding_augmentation]):
        with tf.compat.v1.variable_scope('input_layer'):
            projected = linear(embedding_augmentation, hidden_size, True,
                               True)
            activated = tf.nn.relu(projected)

        if keep_prob and keep_prob < 1.0:
            activated = tf.nn.dropout(activated, rate=1 - (keep_prob))

        with tf.compat.v1.variable_scope('output_layer'):
            augmentation = linear(activated, hidden_size, True, True)

        return x + augmentation
def transformer_ffn_layer(x, params, name=None):
    """Position-wise feed-forward block: linear -> ReLU -> (dropout) -> linear.

    :param x: Input tensor
    :param params: Mapping providing ``filter_size`` (inner width),
        ``hidden_size`` (output width) and ``relu_dropout``
    :param name: Variable scope name (defaults to ``ffn_layer``)
    :return: Output of the second linear layer (no residual is added here)
    """
    filter_size = params['filter_size']
    hidden_size = params['hidden_size']
    keep_prob = 1.0 - params['relu_dropout']

    with tf.compat.v1.variable_scope(
            name, default_name='ffn_layer', values=[x]):
        with tf.compat.v1.variable_scope('input_layer'):
            activated = tf.nn.relu(linear(x, filter_size, True, True))

        if keep_prob and keep_prob < 1.0:
            activated = tf.nn.dropout(activated, rate=1 - (keep_prob))

        with tf.compat.v1.variable_scope('output_layer'):
            return linear(activated, hidden_size, True, True)
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        mask,
                        params={},
                        name='encoder'):
    """Run a stack of self-attention + feed-forward encoder layers.

    Each layer applies self-attention and a feed-forward block, both wrapped
    in pre-processing, a residual connection and post-processing, and then
    multiplies the result by ``mask``.

    NOTE(review): the nesting below (mask multiply inside each layer scope,
    final pre-processing inside the ``name`` scope) was reconstructed from a
    source with lost indentation -- confirm against checkpoint variable
    names.

    :param encoder_input: Input tensor of the first layer
    :param encoder_self_attention_bias: Bias added to the attention logits
    :param mask: Mask tensor; an axis is appended at position 2 so it can
        multiply the layer outputs (presumably [batch, length] -- confirm)
    :param params: Mapping providing ``num_encoder_layers``,
        ``hidden_size``, ``num_heads``, ``residual_dropout``,
        ``attention_dropout``, ``layer_preproc``, ``layer_postproc``,
        ``position_info_type`` and, for relative positions,
        ``max_relative_dis``
    :param name: Variable scope name
    :return: Output of the last layer after ``layer_preproc`` processing
    """
    num_layers = params['num_encoder_layers']
    hidden_size = params['hidden_size']
    num_heads = params['num_heads']
    residual_dropout = params['residual_dropout']
    attention_dropout = params['attention_dropout']
    layer_preproc = params['layer_preproc']
    layer_postproc = params['layer_postproc']

    x = encoder_input
    mask = tf.expand_dims(mask, 2)

    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        for layer_idx in range(num_layers):
            with tf.compat.v1.variable_scope('layer_%d' % layer_idx):
                # Relative position embeddings are only enabled on request.
                if params['position_info_type'] == 'relative':
                    max_relative_dis = params['max_relative_dis']
                else:
                    max_relative_dis = None

                # Self-attention sub-layer.
                attn_in = _layer_process(x, layer_preproc)
                attn_out, _ = multihead_attention(
                    attn_in,
                    None,
                    encoder_self_attention_bias,
                    hidden_size,
                    hidden_size,
                    hidden_size,
                    num_heads,
                    attention_dropout,
                    max_relative_dis=max_relative_dis,
                    name='encoder_self_attention')
                x = _residual_fn(x, attn_out, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)

                # Feed-forward sub-layer.
                ffn_in = _layer_process(x, layer_preproc)
                ffn_out = transformer_ffn_layer(ffn_in, params)
                x = _residual_fn(x, ffn_out, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)

                x = tf.multiply(x, mask)

        return _layer_process(x, layer_preproc)
def transformer_semantic_encoder(encoder_input,
                                 encoder_self_attention_bias,
                                 mask,
                                 params={},
                                 name='mini_xlm_encoder'):
    """Encode a sequence and pool it into a single projected vector.

    Runs ``num_semantic_encoder_layers`` self-attention + feed-forward
    layers (always with relative position information), then divides the
    sum of the outputs over axis 1 by the sum of the mask over axis 1,
    re-inserts axis 1 and applies a linear projection.

    NOTE(review): the nesting below (mask multiply inside each layer scope,
    ``pooling_layer`` and the final pre-processing inside the ``name``
    scope) was reconstructed from a source with lost indentation -- confirm
    against checkpoint variable names.

    :param encoder_input: Input tensor of the first layer
    :param encoder_self_attention_bias: Bias added to the attention logits
    :param mask: Mask tensor; an axis is appended at position 2
        (presumably [batch, length] -- confirm)
    :param params: Mapping providing ``num_semantic_encoder_layers``,
        ``hidden_size``, ``num_heads``, ``residual_dropout``,
        ``attention_dropout``, ``layer_preproc``, ``layer_postproc`` and
        ``max_relative_dis``
    :param name: Variable scope name
    :return: The pooled, projected representation after ``layer_preproc``
        processing
    """
    num_layers = params['num_semantic_encoder_layers']
    hidden_size = params['hidden_size']
    num_heads = params['num_heads']
    residual_dropout = params['residual_dropout']
    attention_dropout = params['attention_dropout']
    layer_preproc = params['layer_preproc']
    layer_postproc = params['layer_postproc']

    x = encoder_input
    mask = tf.expand_dims(mask, 2)

    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        for layer_idx in range(num_layers):
            with tf.compat.v1.variable_scope('layer_%d' % layer_idx):
                max_relative_dis = params['max_relative_dis']

                # Self-attention sub-layer.
                attn_in = _layer_process(x, layer_preproc)
                attn_out, _ = multihead_attention(
                    attn_in,
                    None,
                    encoder_self_attention_bias,
                    hidden_size,
                    hidden_size,
                    hidden_size,
                    num_heads,
                    attention_dropout,
                    max_relative_dis=max_relative_dis,
                    name='encoder_self_attention')
                x = _residual_fn(x, attn_out, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)

                # Feed-forward sub-layer.
                ffn_in = _layer_process(x, layer_preproc)
                ffn_out = transformer_ffn_layer(ffn_in, params)
                x = _residual_fn(x, ffn_out, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)

                x = tf.multiply(x, mask)

        with tf.compat.v1.variable_scope(
                'pooling_layer', reuse=tf.compat.v1.AUTO_REUSE):
            # Sum of outputs over axis 1 divided by the sum of the mask.
            summed = tf.reduce_sum(input_tensor=x, axis=1)
            pooled = summed / tf.reduce_sum(input_tensor=mask, axis=1)
            output = linear(
                tf.expand_dims(pooled, axis=1), hidden_size, True, True)

        return _layer_process(output, layer_preproc)
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        states_key=None,
                        states_val=None,
                        embedding_augmentation=None,
                        params={},
                        name='decoder'):
    """Run a stack of Transformer decoder layers.

    Each layer optionally mixes in ``embedding_augmentation``, then applies
    self-attention, encoder-decoder attention and a feed-forward block, each
    wrapped in pre-processing, a residual connection and post-processing.

    NOTE(review): the nesting below (post-processing of the augmentation
    inside the ``if``, final pre-processing inside the ``name`` scope) was
    reconstructed from a source with lost indentation -- confirm against
    checkpoint variable names.

    :param decoder_input: Input tensor of the first layer
    :param encoder_output: Memory attended to by encoder-decoder attention
    :param decoder_self_attention_bias: Bias for the self-attention logits
    :param encoder_decoder_attention_bias: Bias for the encoder-decoder
        attention logits
    :param states_key: Optional per-layer key cache, forwarded to the
        self-attention call together with the layer index
    :param states_val: Optional per-layer value cache, forwarded likewise
    :param embedding_augmentation: Optional tensor mixed into ``x`` at the
        start of every layer
    :param params: Mapping providing ``num_decoder_layers``,
        ``hidden_size``, ``num_heads``, ``residual_dropout``,
        ``attention_dropout``, ``layer_preproc``, ``layer_postproc``,
        ``position_info_type`` and, for relative positions,
        ``max_relative_dis`` (plus whatever the sub-layers read)
    :param name: Variable scope name
    :return: Tuple of (output after ``layer_preproc`` processing,
        encoder-decoder attention weights of the last layer)
    """
    num_layers = params['num_decoder_layers']
    hidden_size = params['hidden_size']
    num_heads = params['num_heads']
    residual_dropout = params['residual_dropout']
    attention_dropout = params['attention_dropout']
    layer_preproc = params['layer_preproc']
    layer_postproc = params['layer_postproc']

    x = decoder_input

    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        for layer_idx in range(num_layers):
            with tf.compat.v1.variable_scope('layer_%d' % layer_idx):
                if params['position_info_type'] == 'relative':
                    max_relative_dis = params['max_relative_dis']
                else:
                    max_relative_dis = None

                # continuous semantic augmentation
                if embedding_augmentation is not None:
                    augmentation = _layer_process(embedding_augmentation,
                                                  layer_preproc)
                    x = embedding_augmentation_layer(x, augmentation, params)
                    x = _layer_process(x, layer_postproc)

                # Self-attention sub-layer (uses the optional caches).
                self_in = _layer_process(x, layer_preproc)
                self_out, _ = multihead_attention(
                    self_in,
                    None,
                    decoder_self_attention_bias,
                    hidden_size,
                    hidden_size,
                    hidden_size,
                    num_heads,
                    attention_dropout,
                    states_key=states_key,
                    states_val=states_val,
                    layer=layer_idx,
                    max_relative_dis=max_relative_dis,
                    name='decoder_self_attention')
                x = _residual_fn(x, self_out, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)

                # Encoder-decoder attention sub-layer.
                cross_in = _layer_process(x, layer_preproc)
                cross_out, attn_weights = multihead_attention(
                    cross_in,
                    encoder_output,
                    encoder_decoder_attention_bias,
                    hidden_size,
                    hidden_size,
                    hidden_size,
                    num_heads,
                    attention_dropout,
                    max_relative_dis=max_relative_dis,
                    name='encdec_attention')
                x = _residual_fn(x, cross_out, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)

                # Feed-forward sub-layer.
                ffn_in = _layer_process(x, layer_preproc)
                ffn_out = transformer_ffn_layer(ffn_in, params)
                x = _residual_fn(x, ffn_out, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)

        return _layer_process(x, layer_preproc), attn_weights
def add_timing_signal(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add a sinusoidal position signal to ``x``.

    Half of the channels carry sines and the other half cosines of the
    position scaled by geometrically spaced timescales between
    ``min_timescale`` and ``max_timescale``; an odd channel count is padded
    with one zero channel.

    :param x: Tensor indexed as [batch, length, channels] (axes 1 and 2 are
        read as length and channel count)
    :param min_timescale: Smallest timescale of the geometric progression
    :param max_timescale: Largest timescale of the geometric progression
    :return: ``x`` plus the timing signal, in the dtype of ``x``
    """
    length = tf.shape(x)[1]
    channels = tf.shape(x)[2]
    position = tf.cast(tf.range(length), tf.float32)
    num_timescales = channels // 2

    # Clamp the denominator to at least 1: with a single timescale
    # (2 or 3 channels) ``num_timescales - 1`` is zero and the unclamped
    # division would turn the whole signal into NaN.
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale))
        / tf.maximum(tf.cast(num_timescales, tf.float32) - 1, 1.0))
    inv_timescales = min_timescale * tf.exp(
        tf.cast(tf.range(num_timescales), tf.float32)
        * -log_timescale_increment)

    scaled_time = (
        tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0))
    signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
    # Pad one zero channel when the channel count is odd.
    signal = tf.pad(signal, [[0, 0], [0, tf.compat.v1.mod(channels, 2)]])
    signal = tf.reshape(signal, [1, length, channels])

    return x + tf.cast(signal, x.dtype)
def attention_bias(inputs, mode, inf=-1e9, dtype=None):
    """Build an additive bias for attention logits.

    :param inputs: For ``'masking'``, a mask tensor (entries equal to 1
        produce a zero bias); for ``'causal'``, the sequence length
    :param mode: ``'masking'`` or ``'causal'``
    :param inf: Bias value for blocked positions; replaced by
        ``dtype.min`` when ``dtype`` is not float32
    :param dtype: Result dtype (defaults to ``tf.float32``)
    :return: For ``'masking'``, ``(1 - mask) * inf`` with two axes inserted
        at position 1; for ``'causal'``, a ``[1, 1, length, length]`` tensor
        that is ``inf`` above the diagonal and 0 elsewhere
    :raises ValueError: If ``mode`` is neither of the two supported values
    """
    dtype = tf.float32 if dtype is None else dtype
    if dtype != tf.float32:
        inf = dtype.min

    if mode == 'masking':
        blocked = (1.0 - inputs) * inf
        bias = tf.expand_dims(tf.expand_dims(blocked, 1), 1)
        return tf.cast(bias, dtype)

    if mode == 'causal':
        length = inputs
        # Ones on and below the diagonal mark the visible positions.
        visible = tf.linalg.band_part(
            tf.fill([length, length], 1.0), -1, 0)
        bias = tf.reshape(inf * (1.0 - visible), [1, 1, length, length])
        return tf.cast(bias, dtype)

    raise ValueError('Unknown mode %s' % mode)
def split_heads(x, num_heads):
    """Split the last axis of ``x`` into heads and move the head axis to 1.

    :param x: Tensor of known rank whose last axis is divisible by
        ``num_heads``
    :param num_heads: Number of attention heads
    :return: Tensor with the last axis reshaped to ``[num_heads, -1]`` and
        the new head axis transposed to position 1
    """
    static_dims = x.get_shape().dims
    rank = x.shape.ndims
    depth = static_dims[-1]
    per_head = depth // num_heads if depth else None

    split = tf.reshape(x, tf.concat([tf.shape(x)[:-1], [num_heads, -1]], 0))
    # Restore the static shape information lost by the dynamic reshape.
    split.set_shape(static_dims[:-1] + [num_heads, per_head])

    # Head axis (now at index rank - 1) goes right after the batch axis.
    perm = [0, rank - 1] + list(range(1, rank - 1)) + [rank]
    return tf.transpose(split, perm)
def dot_product_attention(q,
                          k,
                          v,
                          bias,
                          dropout_rate=0.0,
                          name=None,
                          rpr=None):
    """Dot-product attention, optionally with relative position terms.

    :param q: Queries, indexed as [batch, heads, length_q, depth_k]
    :param k: Keys, indexed as [batch, heads, length_k, depth_k]
    :param v: Values, indexed as [batch, heads, length_k, depth_v]
    :param bias: Optional tensor added to the attention logits
    :param dropout_rate: Fraction of attention weights to drop; dropout is
        applied only when this is greater than 0
    :param name: Variable scope name (defaults to ``dot_product_attention``)
    :param rpr: Optional dict with ``'rpr_k'`` (length_q, length_k, depth_k)
        and ``'rpr_v'`` (length_q, length_k, depth_v) relative position
        representations
    :return: Tuple ``(outputs, weights)`` with outputs shaped
        (batch, heads, length_q, depth_v) and weights shaped
        (batch, heads, length_q, length_k)
    """
    with tf.compat.v1.variable_scope(
            name, default_name='dot_product_attention', values=[q, k, v]):
        q_shape = tf.shape(q)
        bs, hd, lq, dk = q_shape[0], q_shape[1], q_shape[2], q_shape[3]
        lk = tf.shape(k)[2]
        dv = tf.shape(v)[3]

        if rpr is None:
            logits = tf.matmul(q, k, transpose_b=True)
        else:  # self-attention with relative position representation
            rpr_k, rpr_v = rpr['rpr_k'], rpr['rpr_v']
            logits_part1 = tf.matmul(q, k, transpose_b=True)  # bs, hd, lq, lk
            # Bring the query position to the front so each position is
            # multiplied by its own slice of relative key embeddings.
            q = tf.reshape(tf.transpose(q, [2, 0, 1, 3]),
                           [lq, bs * hd, dk])  # lq, bs*hd, dk
            logits_part2 = tf.matmul(
                q, tf.transpose(rpr_k, [0, 2, 1]))  # lq, bs*hd, lk
            logits_part2 = tf.reshape(
                tf.transpose(logits_part2, [1, 0, 2]), [bs, hd, lq, lk])
            logits = logits_part1 + logits_part2  # bs, hd, lq, lk

        if bias is not None:
            logits += bias

        weights = tf.nn.softmax(logits, name='attention_weights')

        if dropout_rate > 0.0:
            # Pass the drop probability by keyword: the second positional
            # argument of tf.nn.dropout means keep_prob in TF1 but rate in
            # TF2, so a positional ``1.0 - dropout_rate`` inverts under TF2.
            weights = tf.nn.dropout(weights, rate=dropout_rate)

        if rpr is None:
            return tf.matmul(weights, v), weights

        outputs_part1 = tf.matmul(weights, v)  # bs, hd, lq, dv
        weights = tf.reshape(
            tf.transpose(weights, [2, 0, 1, 3]),
            [lq, bs * hd, lk])  # lq, bs*hd, lk
        outputs_part2 = tf.matmul(weights, rpr_v)  # lq, bs*hd, dv
        outputs_part2 = tf.reshape(
            tf.transpose(outputs_part2, [1, 0, 2]), [bs, hd, lq, dv])
        outputs = outputs_part1 + outputs_part2  # bs, hd, lq, dv
        # Undo the position-major layout before returning the weights.
        weights = tf.reshape(
            tf.transpose(weights, [1, 0, 2]),
            [bs, hd, lq, lk])  # bs, hd, lq, lk
        return outputs, weights
def combine_heads(x):
    """Merge the head axis back into the last axis (inverse of head split).

    :param x: Rank-4 tensor; axes 1 and 2 are swapped before the last two
        axes are merged
    :return: Rank-3 tensor whose last axis is the product of the former
        last two axes
    """
    x = tf.transpose(x, [0, 2, 1, 3])
    static_dims = x.get_shape().dims
    heads, per_head = static_dims[-2:]
    merged = heads * per_head if heads and per_head else None

    flat = tf.reshape(x, tf.concat([tf.shape(x)[:-2], [-1]], 0))
    # Restore the static shape information lost by the dynamic reshape.
    flat.set_shape(static_dims[:-2] + [merged])
    return flat
def create_rpr(orginal_var,
               length_q,
               length_kv,
               max_relative_dis,
               name='create_rpr'):
    """Look up relative position embeddings for every query/key pair.

    The signed distance ``row - column`` is shifted by ``max_relative_dis``
    and clipped to ``[0, 2 * max_relative_dis]`` to index ``orginal_var``;
    only the last ``length_q`` rows are kept.

    :param orginal_var: Embedding table indexed by clipped distance
    :param length_q: Number of query positions (trailing rows kept)
    :param length_kv: Number of key/value positions
    :param max_relative_dis: Maximum relative distance before clipping
    :param name: Name scope
    :return: ``tf.gather(orginal_var, ids)`` for the
        (length_q, length_kv) index grid
    """
    with tf.name_scope(name):
        rows = tf.reshape(tf.range(length_kv), [-1, 1])  # only self-attention
        cols = tf.reshape(tf.range(length_kv), [1, -1])
        offsets = (rows - cols) + max_relative_dis
        clipped = tf.minimum(tf.maximum(offsets, 0), 2 * max_relative_dis)
        return tf.gather(orginal_var, clipped[-length_q:, :])
def multihead_attention(queries,
                        memories,
                        bias,
                        key_depth,
                        value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        states_key=None,
                        states_val=None,
                        layer=0,
                        max_relative_dis=None,
                        name=None):
    """Multi-head attention (self-attention when ``memories`` is None).

    :param queries: Query tensor; also the source of keys/values for
        self-attention
    :param memories: Tensor to attend to, or None for self-attention
    :param bias: Optional bias added to the attention logits
    :param key_depth: Total key/query depth; must be divisible by
        ``num_heads``
    :param value_depth: Total value depth; must be divisible by
        ``num_heads``
    :param output_depth: Depth of the output projection
    :param num_heads: Number of attention heads
    :param dropout_rate: Dropout rate applied to the attention weights
    :param states_key: Optional indexable cache; ``states_key[layer]`` is
        concatenated with the new keys on axis 1 and updated in place
    :param states_val: Optional indexable cache for values, handled likewise
    :param layer: Index into ``states_key`` / ``states_val``
    :param max_relative_dis: If not None (and self-attention), enables
        relative position representations with this maximum distance
    :param name: Variable scope name (defaults to ``multihead_attention``)
    :return: Tuple ``(output, weights)`` where ``weights`` is the attention
        distribution averaged over heads
    :raises ValueError: If ``key_depth`` or ``value_depth`` is not divisible
        by ``num_heads``
    """
    # The messages used to reference undefined names (key_size /
    # value_size), which raised NameError instead of the ValueError.
    if key_depth % num_heads != 0:
        raise ValueError(
            'Key size (%d) must be divisible by the number of attention heads (%d).'
            % (key_depth, num_heads))

    if value_depth % num_heads != 0:
        raise ValueError(
            'Value size (%d) must be divisible by the number of attention heads (%d).'
            % (value_depth, num_heads))

    with tf.compat.v1.variable_scope(
            name, default_name='multihead_attention',
            values=[queries, memories]):
        if memories is None:
            # self attention
            combined = linear(
                queries,
                key_depth * 2 + value_depth,
                True,
                True,
                scope='qkv_transform')
            q, k, v = tf.split(
                combined, [key_depth, key_depth, value_depth], axis=2)
        else:
            q = linear(queries, key_depth, True, True, scope='q_transform')
            combined = linear(
                memories,
                key_depth + value_depth,
                True,
                True,
                scope='kv_transform')
            k, v = tf.split(combined, [key_depth, value_depth], axis=2)

        # Extend the caches in place so the caller sees the updated
        # keys/values after this call.
        if states_key is not None:
            k = states_key[layer] = tf.concat([states_key[layer], k], axis=1)
        if states_val is not None:
            v = states_val[layer] = tf.concat([states_val[layer], v], axis=1)

        q = split_heads(q, num_heads)
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)

        # Scale queries by 1/sqrt(per-head key depth).
        key_depth_per_head = key_depth // num_heads
        q *= key_depth_per_head**-0.5

        length_q = tf.shape(q)[2]
        length_kv = tf.shape(k)[2]

        # relative position representation (only in self-attention)
        if memories is None and max_relative_dis is not None:
            rpr_k = tf.compat.v1.get_variable(
                'rpr_k', [2 * max_relative_dis + 1, key_depth // num_heads])
            rpr_v = tf.compat.v1.get_variable(
                'rpr_v', [2 * max_relative_dis + 1, value_depth // num_heads])
            rpr_k = create_rpr(rpr_k, length_q, length_kv, max_relative_dis)
            rpr_v = create_rpr(rpr_v, length_q, length_kv, max_relative_dis)
            rpr = {'rpr_k': rpr_k, 'rpr_v': rpr_v}
            x, w = dot_product_attention(q, k, v, bias, dropout_rate, rpr=rpr)
        else:
            x, w = dot_product_attention(q, k, v, bias, dropout_rate)

        x = combine_heads(x)
        # Average the attention weights over the head axis.
        w = tf.reduce_mean(w, 1)
        x = linear(x, output_depth, True, True, scope='output_transform')
        return x, w
def get_initializer(params):
    """Build the variable initializer selected by ``params['initializer']``.

    :param params: Mapping providing ``initializer`` (one of ``'uniform'``,
        ``'normal'``, ``'normal_unit_scaling'``, ``'uniform_unit_scaling'``)
        and ``initializer_scale``
    :return: The corresponding ``tf.compat.v1`` initializer
    :raises ValueError: If the initializer name is not recognized
    """
    kind = params['initializer']

    if kind == 'uniform':
        bound = params['initializer_scale']
        return tf.compat.v1.random_uniform_initializer(-bound, bound)

    if kind == 'normal':
        return tf.compat.v1.random_normal_initializer(
            0.0, params['initializer_scale'])

    if kind in ('normal_unit_scaling', 'uniform_unit_scaling'):
        # Both variants differ only in the sampling distribution.
        distribution = 'normal' if kind == 'normal_unit_scaling' else 'uniform'
        return tf.compat.v1.variance_scaling_initializer(
            params['initializer_scale'],
            mode='fan_avg',
            distribution=distribution)

    raise ValueError('Unrecognized initializer: %s' % kind)
def get_learning_rate_decay(learning_rate, global_step, params):
    """Apply the schedule named by ``params['learning_rate_decay']``.

    :param learning_rate: Base learning rate
    :param global_step: Current training step
    :param params: Mapping providing ``learning_rate_decay`` and, depending
        on the schedule, ``warmup_steps`` and ``hidden_size`` (for
        ``'linear_warmup_rsqrt_decay'`` / ``'noam'``) or
        ``learning_rate_boundaries`` and ``learning_rate_values`` (for
        ``'piecewise_constant'``)
    :return: The scheduled learning rate; ``learning_rate`` itself for
        ``'none'``
    :raises ValueError: If the schedule name is not recognized
    """
    schedule = params['learning_rate_decay']

    if schedule == 'none':
        return learning_rate

    if schedule == 'piecewise_constant':
        return tf.compat.v1.train.piecewise_constant(
            tf.cast(global_step, dtype=tf.int32),
            params['learning_rate_boundaries'],
            params['learning_rate_values'])

    if schedule in ('linear_warmup_rsqrt_decay', 'noam'):
        step = tf.cast(global_step, dtype=tf.float32)
        warmup_steps = tf.cast(params['warmup_steps'], dtype=tf.float32)
        multiplier = params['hidden_size']**-0.5
        # Linear ramp during warm-up, inverse square root decay afterwards.
        warmup_term = (step + 1) * (warmup_steps**-1.5)
        decay_term = (step + 1)**-0.5
        return learning_rate * (
            multiplier * tf.minimum(warmup_term, decay_term))

    raise ValueError('Unknown learning_rate_decay')
def average_gradients(tower_grads):
    """Average gradients variable-wise across towers.

    :param tower_grads: One list of ``(gradient, variable)`` pairs per
        tower, all in the same variable order
    :return: List of ``(mean gradient, variable)`` pairs, taking the
        variable from the first tower
    """
    averaged = []
    for per_variable in zip(*tower_grads):
        # Stack the towers' gradients on a new leading axis and average it.
        stacked = tf.concat(
            axis=0,
            values=[tf.expand_dims(g, 0) for g, _ in per_variable])
        mean_grad = tf.reduce_mean(stacked, 0)
        averaged.append((mean_grad, per_variable[0][1]))
    return averaged
# Distributed communication backend; all_reduce is a no-op while it is None.
_ENGINE = None


def all_reduce(tensor):
    """All-reduce ``tensor`` through ``_ENGINE`` when one is configured.

    :param tensor: Tensor to reduce
    :return: ``tensor`` unchanged when ``_ENGINE`` is None, otherwise the
        result of ``_ENGINE.allreduce`` with fp16 compression
    """
    if _ENGINE is not None:
        return _ENGINE.allreduce(
            tensor, compression=_ENGINE.Compression.fp16)
    return tensor
class MultiStepOptimizer(tf.compat.v1.train.Optimizer):
    """Wrap an optimizer so gradients are accumulated over several calls.

    With ``step == 1`` the wrapper only all-reduces the gradients and
    delegates to the wrapped optimizer. Otherwise gradients are summed into
    per-variable ``grad_acc`` slots; a non-slot ``iter`` counter starts at 1
    and is advanced modulo ``step`` on every update, and whenever it is 0
    the averaged, all-reduced gradients are applied and the accumulators
    are reset to zero.
    """

    def __init__(self,
                 optimizer,
                 step=1,
                 use_locking=False,
                 name='MultiStepOptimizer'):
        """Store the wrapped optimizer and the accumulation length.

        :param optimizer: Optimizer that computes and applies the gradients
        :param step: Number of calls to accumulate before applying
        :param use_locking: Forwarded to the base class and the state ops
        :param name: Name of this optimizer
        """
        super(MultiStepOptimizer, self).__init__(use_locking, name)
        self._optimizer = optimizer
        self._step = step
        self._step_t = tf.convert_to_tensor(step, name='step')

    def _all_reduce(self, tensor):
        """All-reduce ``tensor`` (densifying IndexedSlices); None passes through."""
        with tf.name_scope(self._name + '_Allreduce'):
            if tensor is None:
                return tensor
            if isinstance(tensor, tf.IndexedSlices):
                tensor = tf.convert_to_tensor(tensor)
            return all_reduce(tensor)

    def compute_gradients(self,
                          loss,
                          var_list=None,
                          gate_gradients=tf.compat.v1.train.Optimizer.GATE_OP,
                          aggregation_method=None,
                          colocate_gradients_with_ops=False,
                          grad_loss=None):
        """Compute gradients and fold them into the accumulators.

        All arguments are forwarded to the wrapped optimizer's
        ``compute_gradients``.

        :return: List of ``(gradient, variable)`` pairs. With ``step == 1``
            the gradients are the all-reduced raw gradients; otherwise each
            gradient is the running accumulator, or its all-reduced average
            over ``step`` when the ``iter`` counter is 0.
        """
        grads_and_vars = self._optimizer.compute_gradients(
            loss, var_list, gate_gradients, aggregation_method,
            colocate_gradients_with_ops, grad_loss)
        grads, var_list = list(zip(*grads_and_vars))

        # Do not create extra variables when step is 1
        if self._step == 1:
            grads = [self._all_reduce(t) for t in grads]
            return list(zip(grads, var_list))

        first_var = min(var_list, key=lambda x: x.name)
        # step != 1 here, so the counter always starts at 1: the first call
        # already contributes one accumulated gradient.
        iter_var = self._create_non_slot_variable(
            initial_value=1, name='iter', colocate_with=first_var)

        new_grads = []
        for grad, var in zip(grads, var_list):
            grad_acc = self._zeros_slot(var, 'grad_acc', self._name)

            # tf.compat.v1.* state ops, consistent with the rest of the
            # file; the bare tf.scatter_add / tf.assign_add names do not
            # exist in TF2.
            if isinstance(grad, tf.IndexedSlices):
                grad_acc = tf.compat.v1.scatter_add(
                    grad_acc,
                    grad.indices,
                    grad.values,
                    use_locking=self._use_locking)
            else:
                grad_acc = tf.compat.v1.assign_add(
                    grad_acc, grad, use_locking=self._use_locking)

            # Both closures are consumed by tf.cond within this iteration,
            # so capturing the loop variable grad_acc is safe.
            def _acc_grad():
                return grad_acc

            def _avg_grad():
                return self._all_reduce(grad_acc / self._step)

            grad = tf.cond(tf.equal(iter_var, 0), _avg_grad, _acc_grad)
            new_grads.append(grad)

        return list(zip(new_grads, var_list))

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Apply the gradients every ``step``-th call and advance ``iter``.

        :param grads_and_vars: Pairs as returned by ``compute_gradients``
        :param global_step: Forwarded to the wrapped optimizer
        :param name: Forwarded to the wrapped optimizer
        :return: An op grouping the (conditional) update and the counter
            increment; with ``step == 1`` the wrapped optimizer's update op
        """
        if self._step == 1:
            return self._optimizer.apply_gradients(
                grads_and_vars, global_step, name=name)

        grads, var_list = list(zip(*grads_and_vars))

        def _pass_gradients():
            # Only force evaluation of the accumulation ops.
            return tf.group(*grads)

        def _apply_gradients():
            op = self._optimizer.apply_gradients(
                zip(grads, var_list), global_step, name)
            # Reset the accumulators only after the update has been applied.
            with tf.control_dependencies([op]):
                zero_ops = []
                for var in var_list:
                    grad_acc = self.get_slot(var, 'grad_acc')
                    zero_ops.append(
                        grad_acc.assign(
                            tf.zeros_like(grad_acc),
                            use_locking=self._use_locking))
                zero_op = tf.group(*zero_ops)
            return tf.group(*[op, zero_op])

        iter_var = self._get_non_slot_variable(
            'iter', tf.compat.v1.get_default_graph())
        update_op = tf.cond(
            tf.equal(iter_var, 0), _apply_gradients, _pass_gradients)

        # Advance the counter only after the update/accumulation has run.
        with tf.control_dependencies([update_op]):
            iter_op = iter_var.assign(
                tf.compat.v1.mod(iter_var + 1, self._step_t),
                use_locking=self._use_locking)

        return tf.group(*[update_op, iter_op])
def shard_features(x, num_datashards):
    """Split ``x`` along axis 0 into ``num_datashards`` near-equal pieces.

    The first ``batch_size % num_datashards`` shards receive one extra row.

    :param x: Anything convertible to a tensor
    :param num_datashards: Number of shards to produce
    :return: List of ``num_datashards`` tensors from ``tf.split``
    """
    x = tf.convert_to_tensor(x)
    batch_size = tf.shape(x)[0]

    def _larger_share():
        return batch_size // num_datashards + 1

    def _smaller_share():
        return batch_size // num_datashards

    # The size computation is pinned to the CPU.
    with tf.device('/cpu:0'):
        size_splits = [
            tf.cond(
                tf.greater(
                    tf.compat.v1.mod(batch_size, num_datashards), shard),
                _larger_share, _smaller_share)
            for shard in range(num_datashards)
        ]

    return tf.split(x, size_splits, axis=0)
|