translation.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574
  1. # Part of the implementation is borrowed and modified from THUMT,
  2. # publicly available at https://github.com/THUNLP-MT/THUMT
  3. # Copyright 2017-2022 The Alibaba MT Team Authors. All rights reserved.
  4. import math
  5. from collections import namedtuple
  6. from typing import Dict
  7. import tensorflow as tf
  8. from modelscope.metainfo import Models
  9. from modelscope.models.base import Model, Tensor
  10. from modelscope.models.builder import MODELS
  11. from modelscope.utils.constant import Tasks
  # Public API of this module: only the translation model class is exported.
  __all__ = [
      'CsanmtForTranslation',
  ]
  13. @MODELS.register_module(Tasks.translation, module_name=Models.translation)
  14. class CsanmtForTranslation(Model):
  def __init__(self, model_dir, *args, **kwargs):
      """Create the CSANMT translation model wrapper.

      Args:
          model_dir: model directory, forwarded unchanged to
              ``Model.__init__``.
          *args: extra positional arguments forwarded to ``Model.__init__``.
          **kwargs: the model configuration. It is forwarded to
              ``Model.__init__`` and also kept as ``self.params``, which the
              graph-building methods read (e.g. ``hidden_size``,
              ``trg_vocab_size``, ``beam_size``).
      """
      import logging

      super().__init__(model_dir, *args, **kwargs)
      self.params = kwargs
      # Previously a bare ``print`` dumped the whole configuration to stdout
      # on every construction; keep it available at debug level instead.
      logging.getLogger(__name__).debug('CsanmtForTranslation params: %s',
                                        self.params)
  def __call__(self,
               input: Dict[str, Tensor],
               label: Dict[str, Tensor] = None,
               prefix: Dict[str, Tensor] = None,
               prefix_hit: Dict[bool, Tensor] = None) -> Dict[str, Tensor]:
      """Build either the training graph or the beam-search graph.

      Args:
          input: the preprocessed source sequence.
          label: the ground-truth target data. When given, the training graph
              is built; when ``None``, the inference (beam search) graph is
              built instead.
          prefix: the preprocessed target prefix sequence for interactive
              translation (inference only).
          prefix_hit: the preprocessed target prefix subword vector for
              interactive translation (inference only).

      Returns:
          ``{'train_op': ..., 'loss': ...}`` when ``label`` is given,
          otherwise ``{'output_seqs': ..., 'output_scores': ...}``.
      """
      if label is not None:
          # Training path: transformer_model_train_fn opens its own
          # 'NmtModel' variable scope.
          train_op, loss = self.transformer_model_train_fn(input, label)
          return {'train_op': train_op, 'loss': loss}

      search_features = {
          'input_wids': input,
          'prefix_wids': prefix,
          'prefix_hit': prefix_hit,
      }
      with tf.compat.v1.variable_scope('NmtModel'):
          output_seqs, output_scores = self.beam_search(
              search_features, self.params)
      return {'output_seqs': output_seqs, 'output_scores': output_scores}
  def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
      """Placeholder forward pass; performs no computation.

      Graph construction for this model happens in ``__call__``. This method
      presumably exists to satisfy the ``Model`` base-class interface
      (TODO confirm).

      Args:
          input (Dict[str, Tensor]): the dict of model inputs (ignored).

      Returns:
          None, despite the ``Dict`` annotation.
      """
      return None
  def encoding_graph(self, features, params):
      """Embed source ids and run them through the transformer encoder.

      Args:
          features: source token ids. They are concatenated with an int64
              zero column, so presumably an int64 tensor shaped
              [batch, src_len] (TODO confirm).
          params (dict): reads ``src_vocab_size``, ``hidden_size``,
              ``shared_source_target_embedding``, ``position_info_type`` and
              ``residual_dropout``; the whole dict is also passed to
              ``transformer_encoder``.

      Returns:
          tuple: ``(encoder_output, encoder_self_attention_bias)``.
      """
      vocab_size = params['src_vocab_size']
      d_model = params['hidden_size']
      init = tf.compat.v1.random_normal_initializer(
          0.0, d_model**-0.5, dtype=tf.float32)

      # A shared table is also fetched by the decoder, hence AUTO_REUSE there.
      if params['shared_source_target_embedding']:
          emb_scope, emb_reuse = 'Shared_Embedding', tf.compat.v1.AUTO_REUSE
      else:
          emb_scope, emb_reuse = 'Source_Embedding', None
      with tf.compat.v1.variable_scope(emb_scope, reuse=emb_reuse):
          embedding = tf.compat.v1.get_variable(
              'Weights', [vocab_size, d_model], initializer=init)
      input_bias = tf.compat.v1.get_variable('encoder_input_bias', [d_model])

      # Append id 0 (the symbol beam search treats as end-of-sequence).
      padded_ids = tf.concat(
          [features, tf.zeros_like(features, dtype=tf.int64)[:, :1]], 1)
      # Non-zero mask shifted right by one: position 0 is always valid and,
      # for right-padded input, the first 0 after the tokens (the EOS) is
      # also kept as a valid position.
      nonzero = tf.cast(tf.not_equal(padded_ids, 0), dtype=tf.float32)
      mask = tf.pad(
          tensor=nonzero[:, :-1],
          paddings=[[0, 0], [1, 0]],
          constant_values=1)

      x = tf.gather(embedding, tf.cast(padded_ids, tf.int32)) * (d_model**0.5)
      if params['position_info_type'] == 'absolute':
          x = add_timing_signal(x)
      x = tf.nn.bias_add(tf.multiply(x, tf.expand_dims(mask, 2)), input_bias)
      self_attention_bias = attention_bias(mask, 'masking')
      if params['residual_dropout'] > 0.0:
          x = tf.nn.dropout(x, rate=params['residual_dropout'])

      encoded = transformer_encoder(x, self_attention_bias, mask, params)
      return encoded, self_attention_bias
  def semantic_encoding_graph(self, features, params, name=None):
      """Encode ids with the semantic encoder used for contrastive sampling.

      Unlike ``encoding_graph`` this path adds no timing signal and no input
      bias before calling ``transformer_semantic_encoder``.

      Args:
          features: token ids (concatenated with an int64 zero column).
          params (dict): model configuration; reads ``hidden_size``,
              ``shared_source_target_embedding``, ``src_vocab_size`` /
              ``trg_vocab_size`` and ``residual_dropout``.
          name: ``'source'`` or ``'target'``; selects the embedding table
              when embeddings are not shared and is ignored otherwise.

      Returns:
          The output of ``transformer_semantic_encoder``.

      Raises:
          ValueError: if embeddings are not shared and ``name`` is neither
              ``'source'`` nor ``'target'``.
      """
      d_model = params['hidden_size']
      init = tf.compat.v1.random_normal_initializer(
          0.0, d_model**-0.5, dtype=tf.float32)

      # Choose which vocabulary / embedding table this side uses.
      if params['shared_source_target_embedding']:
          vocab_key, table_scope = 'src_vocab_size', 'Shared_Semantic_Embedding'
      elif name == 'source':
          vocab_key, table_scope = 'src_vocab_size', 'Source_Semantic_Embedding'
      elif name == 'target':
          vocab_key, table_scope = 'trg_vocab_size', 'Target_Semantic_Embedding'
      else:
          raise ValueError('error: no right name specified.')
      vocab_size = params[vocab_key]

      with tf.compat.v1.variable_scope(
              table_scope, reuse=tf.compat.v1.AUTO_REUSE):
          table = tf.compat.v1.get_variable(
              'Weights', [vocab_size, d_model], initializer=init)

      ids = tf.concat(
          [features, tf.zeros_like(features, dtype=tf.int64)[:, :1]], 1)
      # Same right-shifted non-zero mask as in encoding_graph.
      nonzero = tf.cast(tf.not_equal(ids, 0), dtype=tf.float32)
      mask = tf.pad(
          tensor=nonzero[:, :-1],
          paddings=[[0, 0], [1, 0]],
          constant_values=1)

      x = tf.gather(table, tf.cast(ids, tf.int32)) * (d_model**0.5)
      x = tf.multiply(x, tf.expand_dims(mask, 2))
      self_attention_bias = attention_bias(mask, 'masking')
      if params['residual_dropout'] > 0.0:
          x = tf.nn.dropout(x, rate=params['residual_dropout'])

      return transformer_semantic_encoder(x, self_attention_bias, mask,
                                          params)
  def build_contrastive_training_graph(self, features, labels, params):
      """Encode source and target ids with the semantic encoder.

      Args:
          features: source token ids.
          labels: target token ids.
          params (dict): model configuration.

      Returns:
          tuple: ``(feature_output, label_output)``, the semantic encodings of
          ``features`` and ``labels`` respectively.
      """
      # With a shared table the side name is unused by semantic_encoding_graph.
      if params['shared_source_target_embedding']:
          src_side = tgt_side = None
      else:
          src_side, tgt_side = 'source', 'target'
      return (self.semantic_encoding_graph(features, params, name=src_side),
              self.semantic_encoding_graph(labels, params, name=tgt_side))
  def MGMC_sampling(self, x_embedding, y_embedding, params, epsilon=1e-12):
      """Draw interpolated samples between two semantic representations.

      Half of the ``num_of_samples`` samples start at ``x_embedding`` and move
      along ``y_embedding - x_embedding``; the other half go the opposite way.
      Each step size ``omega`` mixes a direction-aware normal noise (stddev
      ``w_r``, the per-position rescaled ``|y - x|``) with standard normal
      noise, weighted by ``eta``.

      Args:
          x_embedding: source-side representation.
          y_embedding: target-side representation; normalisation reduces
              over axis 2, so both are presumably rank 3 (TODO confirm).
          params (dict): reads ``num_of_samples`` and ``eta``.
          epsilon (float): small constant keeping the rescaling finite.

      Returns:
          The ``num_of_samples`` samples concatenated along axis 0.

      Raises:
          ValueError: if ``num_of_samples`` is not a positive even integer.
      """
      num_samples = params['num_of_samples']
      eta = params['eta']
      # An exception (not ``assert``) so the check survives ``python -O``.
      # K must be even (two symmetric halves) and positive (tf.concat needs
      # at least one tensor).
      if num_samples <= 0 or num_samples % 2 != 0:
          raise ValueError(
              'num_of_samples must be a positive even integer, got {!r}'
              .format(num_samples))

      def get_samples(x_vector, y_vector):
          bias_vector = y_vector - x_vector
          abs_bias = tf.abs(bias_vector)
          min_abs = tf.reduce_min(input_tensor=abs_bias, axis=2, keepdims=True)
          max_abs = tf.reduce_max(input_tensor=abs_bias, axis=2, keepdims=True)
          # Rescale |bias| into (0, 1) per position; epsilon keeps both the
          # numerator and the denominator strictly positive.
          w_r = tf.math.divide(abs_bias - min_abs + epsilon,
                               max_abs - min_abs + 2 * epsilon)
          noise_shape = tf.shape(input=bias_vector)
          samples = []
          for _ in range(num_samples // 2):
              omega = eta * tf.random.normal(noise_shape, 0.0, w_r) + \
                  (1.0 - eta) * tf.random.normal(noise_shape, 0.0, 1.0)
              samples.append(x_vector + omega * bias_vector)
          return samples

      all_samples = (get_samples(x_embedding, y_embedding)
                     + get_samples(y_embedding, x_embedding))
      return tf.concat(all_samples, axis=0)
  def decoding_graph(self,
                     encoder_output,
                     encoder_self_attention_bias,
                     labels,
                     params=None,
                     embedding_augmentation=None):
      """Build the teacher-forced decoder and return the training loss.

      The target ids get a trailing 0 appended, are embedded, shifted right
      by one position and decoded with ``transformer_decoder``. The loss is a
      label-smoothed softmax cross-entropy averaged over the masked
      positions.

      Args:
          encoder_output: encoder states (tiled along axis 0 by the caller
              when ``embedding_augmentation`` is used).
          encoder_self_attention_bias: bias returned by ``encoding_graph``;
              passed to ``transformer_decoder``.
          labels: target token ids (concatenated with an int64 zero column).
          params (dict): model configuration. ``None`` (the default) is
              treated as an empty dict, so a missing configuration still
              fails with ``KeyError`` as before.
          embedding_augmentation: optional ``MGMC_sampling`` output forwarded
              to ``transformer_decoder``.

      Returns:
          Scalar loss tensor.
      """
      # ``None`` replaces the former mutable ``{}`` default.
      if params is None:
          params = {}
      trg_vocab_size = params['trg_vocab_size']
      hidden_size = params['hidden_size']
      initializer = tf.compat.v1.random_normal_initializer(
          0.0, hidden_size**-0.5, dtype=tf.float32)
      if params['shared_source_target_embedding']:
          with tf.compat.v1.variable_scope(
                  'Shared_Embedding', reuse=tf.compat.v1.AUTO_REUSE):
              trg_embedding = tf.compat.v1.get_variable(
                  'Weights', [trg_vocab_size, hidden_size],
                  initializer=initializer)
      else:
          with tf.compat.v1.variable_scope('Target_Embedding'):
              trg_embedding = tf.compat.v1.get_variable(
                  'Weights', [trg_vocab_size, hidden_size],
                  initializer=initializer)

      eos_padding = tf.zeros_like(labels, dtype=tf.int64)[:, :1]
      trg_seq = tf.concat([labels, eos_padding], 1)
      # Right-shifted non-zero mask: position 0 is always counted and, for
      # right-padded labels, so is the appended end-of-sequence 0.
      trg_mask = tf.cast(tf.not_equal(trg_seq, 0), dtype=tf.float32)
      shift_trg_mask = tf.pad(
          tensor=trg_mask[:, :-1],
          paddings=[[0, 0], [1, 0]],
          constant_values=1)

      decoder_input = tf.gather(trg_embedding, tf.cast(trg_seq, tf.int32))
      decoder_input *= hidden_size**0.5
      decoder_self_attention_bias = attention_bias(
          tf.shape(input=decoder_input)[1], 'causal')
      # Shift inputs right by one so position t is predicted from tokens < t.
      decoder_input = tf.pad(
          tensor=decoder_input, paddings=[[0, 0], [1, 0], [0, 0]])[:, :-1, :]
      if params['position_info_type'] == 'absolute':
          decoder_input = add_timing_signal(decoder_input)
      # Same guard as the encoder paths; the old rate expression
      # ``1 - (1.0 - residual_dropout)`` was just ``residual_dropout``.
      if params['residual_dropout'] > 0.0:
          decoder_input = tf.nn.dropout(
              decoder_input, rate=params['residual_dropout'])

      # training
      decoder_output, _ = transformer_decoder(
          decoder_input,
          encoder_output,
          decoder_self_attention_bias,
          encoder_self_attention_bias,
          states_key=None,
          states_val=None,
          embedding_augmentation=embedding_augmentation,
          params=params)
      logits = self.prediction(decoder_output, params)

      # Label smoothing: the reference token gets ``confidence``; the rest of
      # the probability mass is spread evenly over the other V - 1 tokens.
      on_value = params['confidence']
      off_value = (1.0 - params['confidence']) / tf.cast(
          trg_vocab_size - 1, dtype=tf.float32)
      soft_targets = tf.one_hot(
          tf.cast(trg_seq, tf.int32),
          depth=trg_vocab_size,
          on_value=on_value,
          off_value=off_value)
      mask = tf.cast(shift_trg_mask, logits.dtype)
      xentropy = tf.nn.softmax_cross_entropy_with_logits(
          logits=logits, labels=tf.stop_gradient(soft_targets)) * mask
      return tf.reduce_sum(input_tensor=xentropy) / tf.reduce_sum(
          input_tensor=mask)
  def build_training_graph(self,
                           features,
                           labels,
                           params,
                           feature_embedding=None,
                           label_embedding=None):
      """Build the (optionally augmented) MLE training loss.

      When both semantic embeddings are supplied, ``MGMC_sampling`` yields
      ``num_of_samples`` augmented copies stacked along axis 0, so the encoder
      output, its attention bias and the labels are tiled along axis 0 to
      match. Otherwise the decoder runs without augmentation.

      Args:
          features: source token ids.
          labels: target token ids.
          params (dict): model configuration.
          feature_embedding: optional semantic encoding of ``features``.
          label_embedding: optional semantic encoding of ``labels``.

      Returns:
          The loss tensor produced by ``decoding_graph``.
      """
      enc_out, enc_bias = self.encoding_graph(features, params)

      augmentation = None
      if not (feature_embedding is None or label_embedding is None):
          augmentation = self.MGMC_sampling(feature_embedding, label_embedding,
                                            params)
          copies = params['num_of_samples']
          enc_out = tf.tile(enc_out, [copies, 1, 1])
          enc_bias = tf.tile(enc_bias, [copies, 1, 1, 1])
          labels = tf.tile(labels, [copies, 1])

      # decode
      return self.decoding_graph(
          enc_out,
          enc_bias,
          labels,
          params,
          embedding_augmentation=augmentation)
  def transformer_model_train_fn(self, features, labels):
      """Build the data-parallel training graph.

      Steps:

      1. Pad the batch so it has at least ``num_gpus`` rows.
      2. Shard it across ``num_gpus`` GPUs, or use a single CPU device when
         ``num_gpus`` is 0.
      3. Build the contrastive-augmented MLE loss per shard.
      4. Combine the per-device gradients with ``average_gradients``.
      5. Optionally clip them by global norm.
      6. Apply them through ``MultiStepOptimizer``.

      Args:
          features: source token ids.
          labels: target token ids.

      Returns:
          tuple: ``(train_op, total_loss)``, where ``total_loss`` is the mean
          of the per-device losses.

      Raises:
          ValueError: if ``params['optimizer']`` is neither ``'sgd'`` nor
              ``'adam'``.
      """
      initializer = get_initializer(self.params)
      with tf.compat.v1.variable_scope('NmtModel', initializer=initializer):
          num_gpus = self.params['num_gpus']
          gradient_clip_norm = self.params['gradient_clip_norm']
          global_step = tf.compat.v1.train.get_global_step()

          # learning rate
          learning_rate = get_learning_rate_decay(
              self.params['learning_rate'], global_step, self.params)
          learning_rate = tf.convert_to_tensor(
              value=learning_rate, dtype=tf.float32)

          # optimizer
          optimizer_name = self.params['optimizer']
          if optimizer_name == 'sgd':
              optimizer = tf.compat.v1.train.GradientDescentOptimizer(
                  learning_rate)
          elif optimizer_name == 'adam':
              optimizer = tf.compat.v1.train.AdamOptimizer(
                  learning_rate=learning_rate,
                  beta1=self.params['adam_beta1'],
                  beta2=self.params['adam_beta2'],
                  epsilon=self.params['adam_epsilon'])
          else:
              # The old ``sys.exit()`` raised NameError (``sys`` is not
              # imported at the top of this file) and would otherwise have
              # killed the host process; raise a catchable error instead.
              raise ValueError(
                  'optimizer not supported: {!r}'.format(optimizer_name))
          opt = MultiStepOptimizer(optimizer, self.params['update_cycle'])

          def fill_gpus(inputs, num_gpus):
              # Repeat the batch num_gpus + 1 times and keep the first
              # num_gpus rows, so a batch smaller than num_gpus still has one
              # row per device.
              outputs = inputs
              for _ in range(num_gpus):
                  outputs = tf.concat([outputs, inputs], axis=0)
              return outputs[:num_gpus, ]

          features = tf.cond(
              pred=tf.shape(input=features)[0] < num_gpus,
              true_fn=lambda: fill_gpus(features, num_gpus),
              false_fn=lambda: features)
          labels = tf.cond(
              pred=tf.shape(input=labels)[0] < num_gpus,
              true_fn=lambda: fill_gpus(labels, num_gpus),
              false_fn=lambda: labels)

          if num_gpus > 0:
              feature_shards = shard_features(features, num_gpus)
              label_shards = shard_features(labels, num_gpus)
              devices = ['gpu:%d' % d for d in range(num_gpus)]
          else:
              feature_shards = [features]
              label_shards = [labels]
              devices = ['cpu:0']

          multi_grads = []
          sharded_losses = []
          for i, device in enumerate(devices):
              # Variables are created on the first device and reused after.
              with tf.device(device), tf.compat.v1.variable_scope(
                      tf.compat.v1.get_variable_scope(),
                      reuse=True if i > 0 else None):
                  with tf.name_scope('%s_%d' % ('GPU', i)):
                      feature_output, label_output = \
                          self.build_contrastive_training_graph(
                              feature_shards[i], label_shards[i],
                              self.params)
                      mle_loss = self.build_training_graph(
                          feature_shards[i], label_shards[i], self.params,
                          feature_output, label_output)
                      sharded_losses.append(mle_loss)
                      tf.compat.v1.summary.scalar('mle_loss_{}'.format(i),
                                                  mle_loss)

                      # Optimization: variables whose names contain
                      # 'Semantic_Embedding' or 'mini_xlm_encoder' get no
                      # gradients (presumably the frozen semantic encoder —
                      # confirm).
                      trainable_vars_list = [
                          v for v in tf.compat.v1.trainable_variables()
                          if 'Semantic_Embedding' not in v.name
                          and 'mini_xlm_encoder' not in v.name
                      ]
                      grads_and_vars = opt.compute_gradients(
                          mle_loss,
                          var_list=trainable_vars_list,
                          colocate_gradients_with_ops=True)
                      multi_grads.append(grads_and_vars)

          total_loss = tf.add_n(sharded_losses) / len(sharded_losses)

          # Average gradients
          grads_and_vars = average_gradients(multi_grads)
          if gradient_clip_norm > 0.0:
              grads, var_list = list(zip(*grads_and_vars))
              grads, _ = tf.clip_by_global_norm(grads, gradient_clip_norm)
              grads_and_vars = list(zip(grads, var_list))
          train_op = opt.apply_gradients(
              grads_and_vars, global_step=global_step)
          return train_op, total_loss
  def prediction(self, decoder_output, params):
      """Project decoder states onto the target vocabulary.

      Args:
          decoder_output: decoder states whose last dimension is
              ``hidden_size``; all leading dimensions are preserved.
          params (dict): reads ``hidden_size``, ``trg_vocab_size``,
              ``shared_embedding_and_softmax_weights`` and
              ``shared_source_target_embedding``.

      Returns:
          Logits with the leading shape of ``decoder_output`` and a last
          dimension of ``trg_vocab_size``.
      """
      hidden_size = params['hidden_size']
      trg_vocab_size = params['trg_vocab_size']
      if params['shared_embedding_and_softmax_weights']:
          # Tie the softmax to the target-side embedding, which must already
          # exist (reuse=True).
          embedding_scope = ('Shared_Embedding'
                             if params['shared_source_target_embedding']
                             else 'Target_Embedding')
          with tf.compat.v1.variable_scope(embedding_scope, reuse=True):
              weights = tf.compat.v1.get_variable('Weights')
      else:
          # Use trg_vocab_size: this branch used to reference the undefined
          # name ``tgt_vocab_size`` and raised NameError.
          weights = tf.compat.v1.get_variable('Softmax',
                                              [trg_vocab_size, hidden_size])
      # Flatten to 2-D for the matmul, then restore the leading dimensions.
      leading_shape = tf.shape(input=decoder_output)[:-1]
      flat_output = tf.reshape(decoder_output, [-1, hidden_size])
      logits = tf.matmul(flat_output, weights, transpose_b=True)
      return tf.reshape(logits,
                        tf.concat([leading_shape, [trg_vocab_size]], 0))
  def inference_func(self,
                     encoder_output,
                     feature_output,
                     encoder_self_attention_bias,
                     trg_seq,
                     states_key,
                     states_val,
                     params=None,
                     is_prefix=False):
      """Run one decoder pass for beam search and return log-probabilities.

      Args:
          encoder_output: output of ``encoding_graph``.
          feature_output: semantic source encoding, passed to the decoder as
              ``embedding_augmentation``.
          encoder_self_attention_bias: bias returned by ``encoding_graph``.
          trg_seq: target ids. The right-shift below drops the last position,
              so callers append a placeholder slot (``beam_search`` appends 0).
          states_key: per-layer key caches forwarded to
              ``transformer_decoder`` and returned as-is (the helper may
              update them in place — TODO confirm).
          states_val: per-layer value caches, handled like ``states_key``.
          params (dict): model configuration. ``None`` (the default) is
              treated as an empty dict, as the old ``{}`` default was.
          is_prefix (bool): when True every position of ``trg_seq`` is
              decoded (forced prefix). Otherwise only the last position is
              fed and only its outputs are returned.

      Returns:
          tuple: ``(log_prob, attention_weights_last, states_key, states_val)``.
      """
      # ``None`` replaces the former mutable ``{}`` default.
      if params is None:
          params = {}
      trg_vocab_size = params['trg_vocab_size']
      hidden_size = params['hidden_size']
      initializer = tf.compat.v1.random_normal_initializer(
          0.0, hidden_size**-0.5, dtype=tf.float32)
      embedding_scope = ('Shared_Embedding'
                         if params['shared_source_target_embedding']
                         else 'Target_Embedding')
      # AUTO_REUSE on both branches. beam_search calls this method twice in
      # one graph when a prefix is given (prefix pass + while-loop body). The
      # non-shared branch previously opened 'Target_Embedding' without reuse,
      # so the second call fails unless an enclosing scope enables reuse.
      with tf.compat.v1.variable_scope(
              embedding_scope, reuse=tf.compat.v1.AUTO_REUSE):
          trg_embedding = tf.compat.v1.get_variable(
              'Weights', [trg_vocab_size, hidden_size],
              initializer=initializer)

      decoder_input = tf.gather(trg_embedding, tf.cast(trg_seq, tf.int32))
      decoder_input *= hidden_size**0.5
      decoder_self_attention_bias = attention_bias(
          tf.shape(input=decoder_input)[1], 'causal')
      # Shift inputs right by one position (drops the placeholder slot).
      decoder_input = tf.pad(
          tensor=decoder_input, paddings=[[0, 0], [1, 0], [0, 0]])[:, :-1, :]
      if params['position_info_type'] == 'absolute':
          decoder_input = add_timing_signal(decoder_input)
      if not is_prefix:
          # Incremental step: feed only the newest position; earlier ones are
          # presumably served from the states_key/states_val caches.
          decoder_input = decoder_input[:, -1:, :]
          decoder_self_attention_bias = decoder_self_attention_bias[:, :,
                                                                    -1:, :]

      decoder_output, attention_weights = transformer_decoder(
          decoder_input,
          encoder_output,
          decoder_self_attention_bias,
          encoder_self_attention_bias,
          states_key=states_key,
          states_val=states_val,
          embedding_augmentation=feature_output,
          params=params)
      if is_prefix:
          decoder_output_last = decoder_output
          attention_weights_last = attention_weights
      else:
          decoder_output_last = decoder_output[:, -1, :]
          attention_weights_last = attention_weights[:, -1, :]

      if params['shared_embedding_and_softmax_weights']:
          with tf.compat.v1.variable_scope(embedding_scope, reuse=True):
              weights = tf.compat.v1.get_variable('Weights')
      else:
          # NOTE(review): like the embedding before the fix above, 'Softmax'
          # is created without reuse and may fail on a second call in one
          # graph; confirm whether the untied configuration supports prefixes.
          weights = tf.compat.v1.get_variable('Softmax',
                                              [trg_vocab_size, hidden_size])
      logits = tf.matmul(decoder_output_last, weights, transpose_b=True)
      log_prob = tf.nn.log_softmax(logits)
      return log_prob, attention_weights_last, states_key, states_val
  442. def beam_search(self, features, params):
  443. beam_size = params['beam_size']
  444. trg_vocab_size = params['trg_vocab_size']
  445. hidden_size = params['hidden_size']
  446. num_decoder_layers = params['num_decoder_layers']
  447. lp_rate = params['lp_rate']
  448. max_decoded_trg_len = params['max_decoded_trg_len']
  449. src_input = features['input_wids']
  450. if 'prefix_wids' in features:
  451. prefix = features['prefix_wids']
  452. prefix_hit = features['prefix_hit']
  453. else:
  454. prefix = None
  455. prefix_hit = None
  456. batch_size = tf.shape(src_input)[0]
  457. src_input = tile_to_beam_size(src_input, beam_size)
  458. src_input = merge_first_two_dims(src_input)
  459. if prefix is not None:
  460. prefix = tf.cast(tile_to_beam_size(prefix, beam_size), tf.int32)
  461. prefix_hit = tile_to_beam_size(prefix_hit, beam_size)
  462. encoder_output, encoder_self_attention_bias = self.encoding_graph(
  463. src_input, params)
  464. source_name = 'source'
  465. if params['shared_source_target_embedding']:
  466. source_name = None
  467. feature_output = self.semantic_encoding_graph(
  468. src_input, params, name=source_name)
  469. states_key = [
  470. tf.fill([batch_size, 0, hidden_size], 0.0)
  471. for layer in range(num_decoder_layers)
  472. ]
  473. states_val = [
  474. tf.fill([batch_size, 0, hidden_size], 0.0)
  475. for layer in range(num_decoder_layers)
  476. ]
  477. for layer in range(num_decoder_layers):
  478. states_key[layer].set_shape(
  479. tf.TensorShape([None, None, hidden_size]))
  480. states_val[layer].set_shape(
  481. tf.TensorShape([None, None, hidden_size]))
  482. states_key = [
  483. tile_to_beam_size(states_key[layer], beam_size)
  484. for layer in range(num_decoder_layers)
  485. ]
  486. states_val = [
  487. tile_to_beam_size(states_val[layer], beam_size)
  488. for layer in range(num_decoder_layers)
  489. ]
  490. fixed_length = 1
  491. if prefix is not None:
  492. init_seqs = tf.concat(
  493. [prefix, tf.fill([batch_size, beam_size, 1], 0)], axis=2)
  494. fixed_length = tf.shape(init_seqs)[-1]
  495. flat_seqs = merge_first_two_dims(init_seqs)
  496. flat_states_key = [
  497. merge_first_two_dims(states_key[layer])
  498. for layer in range(num_decoder_layers)
  499. ]
  500. flat_states_val = [
  501. merge_first_two_dims(states_val[layer])
  502. for layer in range(num_decoder_layers)
  503. ]
  504. step_log_probs, step_attn_weights, step_states_key, step_states_val = self.inference_func(
  505. encoder_output,
  506. feature_output,
  507. encoder_self_attention_bias,
  508. flat_seqs,
  509. flat_states_key,
  510. flat_states_val,
  511. params=params,
  512. is_prefix=True)
  513. states_key = [
  514. split_first_two_dims(step_states_key[layer], batch_size,
  515. beam_size)
  516. for layer in range(num_decoder_layers)
  517. ]
  518. states_val = [
  519. split_first_two_dims(step_states_val[layer], batch_size,
  520. beam_size)
  521. for layer in range(num_decoder_layers)
  522. ]
  523. prefix_hit = merge_first_two_dims(prefix_hit)
  524. log_probs = tf.where(
  525. prefix_hit, step_log_probs[:, -1, :],
  526. tf.ones_like(step_log_probs[:, -1, :]) * tf.float32.min)
  527. init_seqs = tf.concat([
  528. flat_seqs[:, :-1],
  529. tf.expand_dims(
  530. tf.cast(tf.argmax(log_probs, -1), tf.int32), -1)
  531. ], -1)
  532. init_seqs = split_first_two_dims(init_seqs, batch_size, beam_size)
  533. init_seqs = tf.concat(
  534. [init_seqs, tf.fill([batch_size, beam_size, 1], 0)], axis=2)
  535. else:
  536. init_seqs = tf.fill([batch_size, beam_size, 1], 0)
  537. init_log_probs = \
  538. tf.constant([[0.] + [tf.float32.min] * (beam_size - 1)])
  539. init_log_probs = tf.tile(init_log_probs, [batch_size, 1])
  540. init_scores = tf.zeros_like(init_log_probs)
  541. fin_seqs = init_seqs
  542. fin_scores = tf.fill([batch_size, beam_size], tf.float32.min)
  543. fin_flags = tf.cast(tf.fill([batch_size, beam_size], 0), tf.bool)
  544. state = BeamSearchState(
  545. inputs=(init_seqs, init_log_probs, init_scores),
  546. state=(states_key, states_val),
  547. finish=(fin_flags, fin_seqs, fin_scores),
  548. )
  549. def _beam_search_step(time, state):
  550. seqs, log_probs = state.inputs[:2]
  551. states_key, states_val = state.state
  552. flat_seqs = merge_first_two_dims(seqs)
  553. flat_states_key = [
  554. merge_first_two_dims(states_key[layer])
  555. for layer in range(num_decoder_layers)
  556. ]
  557. flat_states_val = [
  558. merge_first_two_dims(states_val[layer])
  559. for layer in range(num_decoder_layers)
  560. ]
  561. step_log_probs, step_attn_weights, step_states_key, step_states_val = self.inference_func(
  562. encoder_output,
  563. feature_output,
  564. encoder_self_attention_bias,
  565. flat_seqs,
  566. flat_states_key,
  567. flat_states_val,
  568. params=params,
  569. is_prefix=False)
  570. step_log_probs = split_first_two_dims(step_log_probs, batch_size,
  571. beam_size)
  572. curr_log_probs = tf.expand_dims(log_probs, 2) + step_log_probs
  573. next_states_key = [
  574. split_first_two_dims(step_states_key[layer], batch_size,
  575. beam_size)
  576. for layer in range(num_decoder_layers)
  577. ]
  578. next_states_val = [
  579. split_first_two_dims(step_states_val[layer], batch_size,
  580. beam_size)
  581. for layer in range(num_decoder_layers)
  582. ]
  583. # Apply length penalty
  584. length_penalty = tf.pow(
  585. (5.0 + tf.cast(time + 1, dtype=tf.float32)) / 6.0, lp_rate)
  586. curr_scores = curr_log_probs / length_penalty
  587. # Select top-k candidates
  588. # [batch_size, beam_size * vocab_size]
  589. curr_scores = tf.reshape(curr_scores,
  590. [-1, beam_size * trg_vocab_size])
  591. # [batch_size, 2 * beam_size]
  592. top_scores, top_indices = tf.nn.top_k(curr_scores, k=2 * beam_size)
  593. # Shape: [batch_size, 2 * beam_size]
  594. beam_indices = top_indices // trg_vocab_size
  595. symbol_indices = top_indices % trg_vocab_size
  596. # Expand sequences
  597. # [batch_size, 2 * beam_size, time]
  598. candidate_seqs = gather_2d(seqs, beam_indices)
  599. candidate_seqs = tf.concat(
  600. [candidate_seqs[:, :, :-1],
  601. tf.expand_dims(symbol_indices, 2)],
  602. axis=2)
  603. pad_seqs = tf.fill([batch_size, 2 * beam_size, 1],
  604. tf.constant(0, tf.int32))
  605. candidate_seqs = tf.concat([candidate_seqs, pad_seqs], axis=2)
  606. # Expand sequences
  607. # Suppress finished sequences
  608. flags = tf.equal(symbol_indices, 0)
  609. # [batch, 2 * beam_size]
  610. alive_scores = top_scores + tf.cast(
  611. flags, dtype=tf.float32) * tf.float32.min
  612. # [batch, beam_size]
  613. alive_scores, alive_indices = tf.nn.top_k(alive_scores, beam_size)
  614. alive_symbols = gather_2d(symbol_indices, alive_indices)
  615. alive_indices = gather_2d(beam_indices, alive_indices)
  616. alive_seqs = gather_2d(seqs, alive_indices)
  617. alive_seqs = tf.concat(
  618. [alive_seqs[:, :, :-1],
  619. tf.expand_dims(alive_symbols, 2)],
  620. axis=2)
  621. pad_seqs = tf.fill([batch_size, beam_size, 1],
  622. tf.constant(0, tf.int32))
  623. alive_seqs = tf.concat([alive_seqs, pad_seqs], axis=2)
  624. alive_states_key = [
  625. gather_2d(next_states_key[layer], alive_indices)
  626. for layer in range(num_decoder_layers)
  627. ]
  628. alive_states_val = [
  629. gather_2d(next_states_val[layer], alive_indices)
  630. for layer in range(num_decoder_layers)
  631. ]
  632. alive_log_probs = alive_scores * length_penalty
  633. # Select finished sequences
  634. prev_fin_flags, prev_fin_seqs, prev_fin_scores = state.finish
  635. # [batch, 2 * beam_size]
  636. step_fin_scores = top_scores + (
  637. 1.0 - tf.cast(flags, dtype=tf.float32)) * tf.float32.min
  638. # [batch, 3 * beam_size]
  639. fin_flags = tf.concat([prev_fin_flags, flags], axis=1)
  640. fin_scores = tf.concat([prev_fin_scores, step_fin_scores], axis=1)
  641. # [batch, beam_size]
  642. fin_scores, fin_indices = tf.nn.top_k(fin_scores, beam_size)
  643. fin_flags = gather_2d(fin_flags, fin_indices)
  644. pad_seqs = tf.fill([batch_size, beam_size, 1],
  645. tf.constant(0, tf.int32))
  646. prev_fin_seqs = tf.concat([prev_fin_seqs, pad_seqs], axis=2)
  647. fin_seqs = tf.concat([prev_fin_seqs, candidate_seqs], axis=1)
  648. fin_seqs = gather_2d(fin_seqs, fin_indices)
  649. new_state = BeamSearchState(
  650. inputs=(alive_seqs, alive_log_probs, alive_scores),
  651. state=(alive_states_key, alive_states_val),
  652. finish=(fin_flags, fin_seqs, fin_scores),
  653. )
  654. return time + 1, new_state
  655. def _is_finished(t, s):
  656. log_probs = s.inputs[1]
  657. finished_flags = s.finish[0]
  658. finished_scores = s.finish[2]
  659. max_lp = tf.pow(
  660. ((5.0 + tf.cast(max_decoded_trg_len, dtype=tf.float32)) / 6.0),
  661. lp_rate)
  662. best_alive_score = log_probs[:, 0] / max_lp
  663. worst_finished_score = tf.reduce_min(
  664. input_tensor=finished_scores
  665. * tf.cast(finished_flags, dtype=tf.float32),
  666. axis=1)
  667. add_mask = 1.0 - tf.cast(
  668. tf.reduce_any(input_tensor=finished_flags, axis=1),
  669. dtype=tf.float32)
  670. worst_finished_score += tf.float32.min * add_mask
  671. bound_is_met = tf.reduce_all(
  672. input_tensor=tf.greater(worst_finished_score,
  673. best_alive_score))
  674. cond = tf.logical_and(
  675. tf.less(t, max_decoded_trg_len), tf.logical_not(bound_is_met))
  676. return cond
  677. def _loop_fn(t, s):
  678. outs = _beam_search_step(t, s)
  679. return outs
  680. time = tf.constant(0, name='time')
  681. shape_invariants = BeamSearchState(
  682. inputs=(tf.TensorShape([None, None, None]),
  683. tf.TensorShape([None, None]), tf.TensorShape([None,
  684. None])),
  685. state=([
  686. tf.TensorShape([None, None, None, hidden_size])
  687. for layer in range(num_decoder_layers)
  688. ], [
  689. tf.TensorShape([None, None, None, hidden_size])
  690. for layer in range(num_decoder_layers)
  691. ]),
  692. finish=(tf.TensorShape([None,
  693. None]), tf.TensorShape([None, None, None]),
  694. tf.TensorShape([None, None])))
  695. outputs = tf.while_loop(
  696. cond=_is_finished,
  697. body=_loop_fn,
  698. loop_vars=[time, state],
  699. shape_invariants=[tf.TensorShape([]), shape_invariants],
  700. parallel_iterations=1,
  701. back_prop=False)
  702. final_state = outputs[1]
  703. alive_seqs = final_state.inputs[0]
  704. alive_scores = final_state.inputs[2]
  705. final_flags = final_state.finish[0]
  706. final_seqs = final_state.finish[1]
  707. final_scores = final_state.finish[2]
  708. alive_seqs.set_shape([None, beam_size, None])
  709. final_seqs.set_shape([None, beam_size, None])
  710. final_seqs = tf.compat.v1.where(
  711. tf.reduce_any(input_tensor=final_flags, axis=1), final_seqs,
  712. alive_seqs)
  713. final_scores = tf.compat.v1.where(
  714. tf.reduce_any(input_tensor=final_flags, axis=1), final_scores,
  715. alive_scores)
  716. final_seqs = final_seqs[:, :, fixed_length - 1:-1]
  717. return final_seqs, final_scores
  718. class BeamSearchState(
  719. namedtuple('BeamSearchState', ('inputs', 'state', 'finish'))):
  720. pass
  721. def tile_to_beam_size(tensor, beam_size):
  722. """Tiles a given tensor by beam_size. """
  723. tensor = tf.expand_dims(tensor, axis=1)
  724. tile_dims = [1] * tensor.shape.ndims
  725. tile_dims[1] = beam_size
  726. return tf.tile(tensor, tile_dims)
  727. def infer_shape(x):
  728. x = tf.convert_to_tensor(x)
  729. if x.shape.dims is None:
  730. return tf.shape(x)
  731. static_shape = x.shape.as_list()
  732. dynamic_shape = tf.shape(x)
  733. ret = []
  734. for i in range(len(static_shape)):
  735. dim = static_shape[i]
  736. if dim is None:
  737. dim = dynamic_shape[i]
  738. ret.append(dim)
  739. return ret
  740. def split_first_two_dims(tensor, dim_0, dim_1):
  741. shape = infer_shape(tensor)
  742. new_shape = [dim_0] + [dim_1] + shape[1:]
  743. return tf.reshape(tensor, new_shape)
  744. def merge_first_two_dims(tensor):
  745. shape = infer_shape(tensor)
  746. shape[0] *= shape[1]
  747. shape.pop(1)
  748. return tf.reshape(tensor, shape)
  749. def gather_2d(params, indices, name=None):
  750. """ Gather the 2nd dimension given indices
  751. :param params: A tensor with shape [batch_size, M, ...]
  752. :param indices: A tensor with shape [batch_size, N]
  753. :param name: An optional string
  754. :return: A tensor with shape [batch_size, N, ...]
  755. """
  756. batch_size = tf.shape(params)[0]
  757. range_size = tf.shape(indices)[1]
  758. batch_pos = tf.range(batch_size * range_size) // range_size
  759. batch_pos = tf.reshape(batch_pos, [batch_size, range_size])
  760. indices = tf.stack([batch_pos, indices], axis=-1)
  761. output = tf.gather_nd(params, indices, name=name)
  762. return output
  763. def linear(inputs, output_size, bias, concat=True, dtype=None, scope=None):
  764. with tf.compat.v1.variable_scope(
  765. scope, default_name='linear', values=[inputs], dtype=dtype):
  766. if not isinstance(inputs, (list, tuple)):
  767. inputs = [inputs]
  768. input_size = [item.get_shape()[-1] for item in inputs]
  769. if len(inputs) != len(input_size):
  770. raise RuntimeError('inputs and input_size unmatched!')
  771. output_shape = tf.concat([tf.shape(inputs[0])[:-1], [output_size]],
  772. axis=0)
  773. # Flatten to 2D
  774. inputs = [tf.reshape(inp, [-1, inp.shape[-1]]) for inp in inputs]
  775. results = []
  776. if concat:
  777. input_size = sum(input_size)
  778. inputs = tf.concat(inputs, 1)
  779. shape = [input_size, output_size]
  780. matrix = tf.compat.v1.get_variable('matrix', shape)
  781. results.append(tf.matmul(inputs, matrix))
  782. else:
  783. for i in range(len(input_size)):
  784. shape = [input_size[i], output_size]
  785. name = 'matrix_%d' % i
  786. matrix = tf.compat.v1.get_variable(name, shape)
  787. results.append(tf.matmul(inputs[i], matrix))
  788. output = tf.add_n(results)
  789. if bias:
  790. shape = [output_size]
  791. bias = tf.compat.v1.get_variable('bias', shape)
  792. output = tf.nn.bias_add(output, bias)
  793. output = tf.reshape(output, output_shape)
  794. return output
  795. def layer_norm(inputs, epsilon=1e-6, name=None, reuse=None):
  796. with tf.compat.v1.variable_scope(
  797. name, default_name='layer_norm', values=[inputs], reuse=reuse):
  798. channel_size = inputs.get_shape().as_list()[-1]
  799. scale = tf.compat.v1.get_variable(
  800. 'layer_norm_scale', [channel_size],
  801. initializer=tf.ones_initializer())
  802. offset = tf.compat.v1.get_variable(
  803. 'layer_norm_offset', [channel_size],
  804. initializer=tf.zeros_initializer())
  805. mean = tf.reduce_mean(inputs, -1, True)
  806. variance = tf.reduce_mean(tf.square(inputs - mean), -1, True)
  807. norm_inputs = (inputs - mean) * tf.compat.v1.rsqrt(variance + epsilon)
  808. return norm_inputs * scale + offset
  809. def _layer_process(x, mode):
  810. if not mode or mode == 'none':
  811. return x
  812. elif mode == 'layer_norm':
  813. return layer_norm(x)
  814. else:
  815. raise ValueError('Unknown mode %s' % mode)
  816. def _residual_fn(x, y, keep_prob=None):
  817. if keep_prob and keep_prob < 1.0:
  818. y = tf.nn.dropout(y, rate=1 - (keep_prob))
  819. return x + y
  820. def embedding_augmentation_layer(x, embedding_augmentation, params, name=None):
  821. hidden_size = params['hidden_size']
  822. keep_prob = 1.0 - params['relu_dropout']
  823. with tf.compat.v1.variable_scope(
  824. name,
  825. default_name='embedding_augmentation_layer',
  826. values=[x, embedding_augmentation]):
  827. with tf.compat.v1.variable_scope('input_layer'):
  828. hidden = linear(embedding_augmentation, hidden_size, True, True)
  829. hidden = tf.nn.relu(hidden)
  830. if keep_prob and keep_prob < 1.0:
  831. hidden = tf.nn.dropout(hidden, rate=1 - (keep_prob))
  832. with tf.compat.v1.variable_scope('output_layer'):
  833. output = linear(hidden, hidden_size, True, True)
  834. return x + output
  835. def transformer_ffn_layer(x, params, name=None):
  836. filter_size = params['filter_size']
  837. hidden_size = params['hidden_size']
  838. keep_prob = 1.0 - params['relu_dropout']
  839. with tf.compat.v1.variable_scope(
  840. name, default_name='ffn_layer', values=[x]):
  841. with tf.compat.v1.variable_scope('input_layer'):
  842. hidden = linear(x, filter_size, True, True)
  843. hidden = tf.nn.relu(hidden)
  844. if keep_prob and keep_prob < 1.0:
  845. hidden = tf.nn.dropout(hidden, rate=1 - (keep_prob))
  846. with tf.compat.v1.variable_scope('output_layer'):
  847. output = linear(hidden, hidden_size, True, True)
  848. return output
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        mask,
                        params={},
                        name='encoder'):
    """Stack of Transformer encoder layers (self-attention + feed-forward).

    Each of ``params['num_encoder_layers']`` layers lives in variable scope
    ``layer_<i>`` under ``name`` (AUTO_REUSE). A layer applies
    self-attention and then a feed-forward sub-layer, each as
    ``x = postproc(x + dropout(sublayer(preproc(x))))``. It then multiplies
    ``x`` by ``mask``, zeroing positions where the mask is 0. Relative
    position representations are used only when
    ``params['position_info_type'] == 'relative'``.

    Args:
      encoder_input: Embedded source tokens; presumably
        [batch, length, hidden_size] (TODO confirm against caller).
      encoder_self_attention_bias: Additive bias for the self-attention.
      mask: Padding mask. A trailing axis is added before the multiply, so
        presumably [batch, length] with 1 for real tokens (TODO confirm).
      params: Hyper-parameter dict (read only). NOTE(review): the mutable
        ``{}`` default is never usable, since every key is required.
      name: Variable-scope name.

    Returns:
      The encoder output after a final ``layer_preproc`` step.
    """
    num_encoder_layers = params['num_encoder_layers']
    hidden_size = params['hidden_size']
    num_heads = params['num_heads']
    residual_dropout = params['residual_dropout']
    attention_dropout = params['attention_dropout']
    layer_preproc = params['layer_preproc']
    layer_postproc = params['layer_postproc']
    x = encoder_input
    # Add a trailing axis so the mask broadcasts over the hidden dimension.
    mask = tf.expand_dims(mask, 2)
    # NOTE: statement order matters for checkpoint compatibility. layer_norm
    # and linear use default-name scopes, which are uniquified by creation
    # order, so reordering calls renames variables.
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        for layer in range(num_encoder_layers):
            with tf.compat.v1.variable_scope('layer_%d' % layer):
                max_relative_dis = params['max_relative_dis'] \
                    if params['position_info_type'] == 'relative' else None
                # `w` (attention weights) is unused here.
                o, w = multihead_attention(
                    _layer_process(x, layer_preproc),
                    None,
                    encoder_self_attention_bias,
                    hidden_size,
                    hidden_size,
                    hidden_size,
                    num_heads,
                    attention_dropout,
                    max_relative_dis=max_relative_dis,
                    name='encoder_self_attention')
                x = _residual_fn(x, o, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)
                o = transformer_ffn_layer(
                    _layer_process(x, layer_preproc), params)
                x = _residual_fn(x, o, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)
                # Zero out masked (padding) positions after every layer.
                x = tf.multiply(x, mask)
        return _layer_process(x, layer_preproc)
def transformer_semantic_encoder(encoder_input,
                                 encoder_self_attention_bias,
                                 mask,
                                 params={},
                                 name='mini_xlm_encoder'):
    """Encode a sequence into a single pooled, projected semantic vector.

    This runs ``params['num_semantic_encoder_layers']`` encoder layers with
    the same structure as ``transformer_encoder``. It then takes a masked
    mean over axis 1 (the sum of ``x`` divided by the sum of ``mask``) and
    applies a ``linear`` projection. The pooled vector gets a length-1 axis
    at position 1 before the projection.

    Args:
      encoder_input: Embedded tokens; presumably [batch, length, hidden_size]
        (TODO confirm against caller).
      encoder_self_attention_bias: Additive bias for the self-attention.
      mask: Padding mask; a trailing axis is added before it is used.
      params: Hyper-parameter dict (read only). NOTE(review): the mutable
        ``{}`` default is never usable, since every key is required.
      name: Variable-scope name.

    Returns:
      The projected pooled representation after ``layer_preproc``.
    """
    num_encoder_layers = params['num_semantic_encoder_layers']
    hidden_size = params['hidden_size']
    num_heads = params['num_heads']
    residual_dropout = params['residual_dropout']
    attention_dropout = params['attention_dropout']
    layer_preproc = params['layer_preproc']
    layer_postproc = params['layer_postproc']
    x = encoder_input
    mask = tf.expand_dims(mask, 2)
    # NOTE: statement order matters for variable naming (default-name scopes).
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        for layer in range(num_encoder_layers):
            with tf.compat.v1.variable_scope('layer_%d' % layer):
                # NOTE(review): unlike transformer_encoder, this does not check
                # params['position_info_type'], so relative position
                # representations are always used here. Confirm intentional.
                max_relative_dis = params['max_relative_dis']
                o, w = multihead_attention(
                    _layer_process(x, layer_preproc),
                    None,
                    encoder_self_attention_bias,
                    hidden_size,
                    hidden_size,
                    hidden_size,
                    num_heads,
                    attention_dropout,
                    max_relative_dis=max_relative_dis,
                    name='encoder_self_attention')
                x = _residual_fn(x, o, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)
                o = transformer_ffn_layer(
                    _layer_process(x, layer_preproc), params)
                x = _residual_fn(x, o, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)
                x = tf.multiply(x, mask)
        with tf.compat.v1.variable_scope(
                'pooling_layer', reuse=tf.compat.v1.AUTO_REUSE):
            # Masked mean over axis 1. NOTE(review): divides by zero if a row's
            # mask sums to 0.
            output = tf.reduce_sum(
                input_tensor=x, axis=1) / tf.reduce_sum(
                    input_tensor=mask, axis=1)
            output = linear(
                tf.expand_dims(output, axis=1), hidden_size, True, True)
        return _layer_process(output, layer_preproc)
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        states_key=None,
                        states_val=None,
                        embedding_augmentation=None,
                        params={},
                        name='decoder'):
    """Stack of Transformer decoder layers with optional semantic augmentation.

    Each of ``params['num_decoder_layers']`` layers lives in scope
    ``layer_<i>`` under ``name`` (AUTO_REUSE). A layer optionally mixes
    ``embedding_augmentation`` into ``x``. It then applies self-attention,
    encoder-decoder attention over ``encoder_output`` and a feed-forward
    sub-layer, each as ``x = postproc(x + dropout(sublayer(preproc(x))))``.

    Args:
      decoder_input: Embedded target tokens.
      encoder_output: Memory attended by the encoder-decoder attention.
      decoder_self_attention_bias: Additive bias for the self-attention
        (presumably a causal mask; TODO confirm against caller).
      encoder_decoder_attention_bias: Additive bias for the encdec attention.
      states_key, states_val: Optional per-layer key/value caches for
        incremental decoding. ``multihead_attention`` overwrites
        ``states_key[layer]`` / ``states_val[layer]`` in place.
      embedding_augmentation: Optional tensor fed through
        ``embedding_augmentation_layer`` at every layer; ``None`` disables it.
      params: Hyper-parameter dict (read only). NOTE(review): the mutable
        ``{}`` default is never usable, since every key is required.
      name: Variable-scope name.

    Returns:
      ``(output, w)``: the decoder output after a final ``layer_preproc``,
      and the encoder-decoder attention weights of the last layer, averaged
      over heads. NOTE(review): ``w`` is unbound if there are zero layers.
    """
    num_decoder_layers = params['num_decoder_layers']
    hidden_size = params['hidden_size']
    num_heads = params['num_heads']
    residual_dropout = params['residual_dropout']
    attention_dropout = params['attention_dropout']
    layer_preproc = params['layer_preproc']
    layer_postproc = params['layer_postproc']
    x = decoder_input
    # NOTE: statement order matters for variable naming (default-name scopes).
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        for layer in range(num_decoder_layers):
            with tf.compat.v1.variable_scope('layer_%d' % layer):
                max_relative_dis = params['max_relative_dis'] \
                    if params['position_info_type'] == 'relative' else None
                # continuous semantic augmentation
                if embedding_augmentation is not None:
                    x = embedding_augmentation_layer(
                        x, _layer_process(embedding_augmentation,
                                          layer_preproc), params)
                    x = _layer_process(x, layer_postproc)
                # Self-attention; passes the caches so they grow in place.
                o, w = multihead_attention(
                    _layer_process(x, layer_preproc),
                    None,
                    decoder_self_attention_bias,
                    hidden_size,
                    hidden_size,
                    hidden_size,
                    num_heads,
                    attention_dropout,
                    states_key=states_key,
                    states_val=states_val,
                    layer=layer,
                    max_relative_dis=max_relative_dis,
                    name='decoder_self_attention')
                x = _residual_fn(x, o, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)
                # Encoder-decoder attention; its `w` is what gets returned.
                o, w = multihead_attention(
                    _layer_process(x, layer_preproc),
                    encoder_output,
                    encoder_decoder_attention_bias,
                    hidden_size,
                    hidden_size,
                    hidden_size,
                    num_heads,
                    attention_dropout,
                    max_relative_dis=max_relative_dis,
                    name='encdec_attention')
                x = _residual_fn(x, o, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)
                o = transformer_ffn_layer(
                    _layer_process(x, layer_preproc), params)
                x = _residual_fn(x, o, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)
        return _layer_process(x, layer_preproc), w
  993. def add_timing_signal(x, min_timescale=1.0, max_timescale=1.0e4):
  994. length = tf.shape(x)[1]
  995. channels = tf.shape(x)[2]
  996. position = tf.cast(tf.range(length), tf.float32)
  997. num_timescales = channels // 2
  998. log_timescale_increment = \
  999. (math.log(float(max_timescale) / float(min_timescale)) / (tf.cast(num_timescales, tf.float32) - 1))
  1000. inv_timescales = min_timescale * tf.exp(
  1001. tf.cast(tf.range(num_timescales), tf.float32)
  1002. * -log_timescale_increment)
  1003. scaled_time = \
  1004. tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
  1005. signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
  1006. signal = tf.pad(signal, [[0, 0], [0, tf.compat.v1.mod(channels, 2)]])
  1007. signal = tf.reshape(signal, [1, length, channels])
  1008. return x + tf.cast(signal, x.dtype)
  1009. def attention_bias(inputs, mode, inf=-1e9, dtype=None):
  1010. if dtype is None:
  1011. dtype = tf.float32
  1012. if dtype != tf.float32:
  1013. inf = dtype.min
  1014. if mode == 'masking':
  1015. mask = inputs
  1016. ret = (1.0 - mask) * inf
  1017. ret = tf.expand_dims(tf.expand_dims(ret, 1), 1)
  1018. elif mode == 'causal':
  1019. length = inputs
  1020. lower_triangle = tf.linalg.band_part(
  1021. tf.fill([length, length], 1.0), -1, 0)
  1022. ret = inf * (1.0 - lower_triangle)
  1023. ret = tf.reshape(ret, [1, 1, length, length])
  1024. else:
  1025. raise ValueError('Unknown mode %s' % mode)
  1026. return tf.cast(ret, dtype)
  1027. def split_heads(x, num_heads):
  1028. n = num_heads
  1029. old_shape = x.get_shape().dims
  1030. ndims = x.shape.ndims
  1031. last = old_shape[-1]
  1032. new_shape = old_shape[:-1] + [n] + [last // n if last else None]
  1033. ret = tf.reshape(x, tf.concat([tf.shape(x)[:-1], [n, -1]], 0))
  1034. ret.set_shape(new_shape)
  1035. perm = [0, ndims - 1] + [i for i in range(1, ndims - 1)] + [ndims]
  1036. return tf.transpose(ret, perm)
  1037. def dot_product_attention(q,
  1038. k,
  1039. v,
  1040. bias,
  1041. dropout_rate=0.0,
  1042. name=None,
  1043. rpr=None):
  1044. with tf.compat.v1.variable_scope(
  1045. name, default_name='dot_product_attention', values=[q, k, v]):
  1046. q_shape = tf.shape(q)
  1047. bs, hd, lq, dk = q_shape[0], q_shape[1], q_shape[2], q_shape[3]
  1048. lk = tf.shape(k)[2]
  1049. dv = tf.shape(v)[3]
  1050. if rpr is not None:
  1051. rpr_k, rpr_v = rpr['rpr_k'], rpr[
  1052. 'rpr_v'] # (lq, lk, dk), (lq, lk, dv)
  1053. if rpr is None:
  1054. logits = tf.matmul(q, k, transpose_b=True)
  1055. else: # self-attention with relative position representation
  1056. logits_part1 = tf.matmul(q, k, transpose_b=True) # bs, hd, lq, lk
  1057. q = tf.reshape(tf.transpose(q, [2, 0, 1, 3]),
  1058. [lq, bs * hd, dk]) # lq, bs*hd, dk
  1059. logits_part2 = tf.matmul(q,
  1060. tf.transpose(rpr_k,
  1061. [0, 2, 1])) # lq, bs*hd, lk
  1062. logits_part2 = tf.reshape(
  1063. tf.transpose(logits_part2, [1, 0, 2]), [bs, hd, lq, lk])
  1064. logits = logits_part1 + logits_part2 # bs, hd, lq, lk
  1065. if bias is not None:
  1066. logits += bias
  1067. weights = tf.nn.softmax(logits, name='attention_weights')
  1068. if dropout_rate > 0.0:
  1069. weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
  1070. if rpr is None:
  1071. return tf.matmul(weights, v), weights
  1072. else:
  1073. outputs_part1 = tf.matmul(weights, v) # bs, hd, lq, dv
  1074. weights = tf.reshape(
  1075. tf.transpose(weights, [2, 0, 1, 3]),
  1076. [lq, bs * hd, lk]) # lq, bs*hd, lk
  1077. outputs_part2 = tf.matmul(weights, rpr_v) # lq, bs*hd, dv
  1078. outputs_part2 = tf.reshape(
  1079. tf.transpose(outputs_part2, [1, 0, 2]), [bs, hd, lq, dv])
  1080. outputs = outputs_part1 + outputs_part2 # bs, hd, lq, dv
  1081. weights = tf.reshape(
  1082. tf.transpose(weights, [1, 0, 2]),
  1083. [bs, hd, lq, lk]) # bs, hd, lq, lk
  1084. return outputs, weights
  1085. def combine_heads(x):
  1086. x = tf.transpose(x, [0, 2, 1, 3])
  1087. old_shape = x.get_shape().dims
  1088. a, b = old_shape[-2:]
  1089. new_shape = old_shape[:-2] + [a * b if a and b else None]
  1090. x = tf.reshape(x, tf.concat([tf.shape(x)[:-2], [-1]], 0))
  1091. x.set_shape(new_shape)
  1092. return x
  1093. def create_rpr(orginal_var,
  1094. length_q,
  1095. length_kv,
  1096. max_relative_dis,
  1097. name='create_rpr'):
  1098. with tf.name_scope(name):
  1099. idxs = tf.reshape(tf.range(length_kv), [-1, 1]) # only self-attention
  1100. idys = tf.reshape(tf.range(length_kv), [1, -1])
  1101. ids = idxs - idys
  1102. ids = ids + max_relative_dis
  1103. ids = tf.maximum(ids, 0)
  1104. ids = tf.minimum(ids, 2 * max_relative_dis)
  1105. ids = ids[-length_q:, :]
  1106. rpr = tf.gather(orginal_var, ids)
  1107. return rpr
  1108. def multihead_attention(queries,
  1109. memories,
  1110. bias,
  1111. key_depth,
  1112. value_depth,
  1113. output_depth,
  1114. num_heads,
  1115. dropout_rate,
  1116. states_key=None,
  1117. states_val=None,
  1118. layer=0,
  1119. max_relative_dis=None,
  1120. name=None):
  1121. if key_depth % num_heads != 0:
  1122. raise ValueError(
  1123. 'Key size (%d) must be divisible by the number of attention heads (%d).'
  1124. % (key_size, num_heads))
  1125. if value_depth % num_heads != 0:
  1126. raise ValueError(
  1127. 'Value size (%d) must be divisible by the number of attention heads (%d).'
  1128. % (value_size, num_heads))
  1129. with tf.compat.v1.variable_scope(
  1130. name, default_name='multihead_attention',
  1131. values=[queries, memories]):
  1132. if memories is None:
  1133. # self attention
  1134. combined = linear(
  1135. queries,
  1136. key_depth * 2 + value_depth,
  1137. True,
  1138. True,
  1139. scope='qkv_transform')
  1140. q, k, v = tf.split(
  1141. combined, [key_depth, key_depth, value_depth], axis=2)
  1142. else:
  1143. q = linear(queries, key_depth, True, True, scope='q_transform')
  1144. combined = linear(
  1145. memories,
  1146. key_depth + value_depth,
  1147. True,
  1148. True,
  1149. scope='kv_transform')
  1150. k, v = tf.split(combined, [key_depth, value_depth], axis=2)
  1151. if states_key is not None:
  1152. k = states_key[layer] = tf.concat([states_key[layer], k], axis=1)
  1153. if states_val is not None:
  1154. v = states_val[layer] = tf.concat([states_val[layer], v], axis=1)
  1155. q = split_heads(q, num_heads)
  1156. k = split_heads(k, num_heads)
  1157. v = split_heads(v, num_heads)
  1158. key_depth_per_head = key_depth // num_heads
  1159. q *= key_depth_per_head**-0.5
  1160. length_q = tf.shape(q)[2]
  1161. length_kv = tf.shape(k)[2]
  1162. # relative position representation (only in self-attention)
  1163. if memories is None and max_relative_dis is not None:
  1164. rpr_k = tf.compat.v1.get_variable(
  1165. 'rpr_k', [2 * max_relative_dis + 1, key_depth // num_heads])
  1166. rpr_v = tf.compat.v1.get_variable(
  1167. 'rpr_v', [2 * max_relative_dis + 1, value_depth // num_heads])
  1168. rpr_k = create_rpr(rpr_k, length_q, length_kv, max_relative_dis)
  1169. rpr_v = create_rpr(rpr_v, length_q, length_kv, max_relative_dis)
  1170. rpr = {'rpr_k': rpr_k, 'rpr_v': rpr_v}
  1171. x, w = dot_product_attention(q, k, v, bias, dropout_rate, rpr=rpr)
  1172. else:
  1173. x, w = dot_product_attention(q, k, v, bias, dropout_rate)
  1174. x = combine_heads(x)
  1175. w = tf.reduce_mean(w, 1)
  1176. x = linear(x, output_depth, True, True, scope='output_transform')
  1177. return x, w
  1178. def get_initializer(params):
  1179. if params['initializer'] == 'uniform':
  1180. max_val = params['initializer_scale']
  1181. return tf.compat.v1.random_uniform_initializer(-max_val, max_val)
  1182. elif params['initializer'] == 'normal':
  1183. return tf.compat.v1.random_normal_initializer(
  1184. 0.0, params['initializer_scale'])
  1185. elif params['initializer'] == 'normal_unit_scaling':
  1186. return tf.compat.v1.variance_scaling_initializer(
  1187. params['initializer_scale'], mode='fan_avg', distribution='normal')
  1188. elif params['initializer'] == 'uniform_unit_scaling':
  1189. return tf.compat.v1.variance_scaling_initializer(
  1190. params['initializer_scale'],
  1191. mode='fan_avg',
  1192. distribution='uniform')
  1193. else:
  1194. raise ValueError('Unrecognized initializer: %s'
  1195. % params['initializer'])
  1196. def get_learning_rate_decay(learning_rate, global_step, params):
  1197. if params['learning_rate_decay'] in ['linear_warmup_rsqrt_decay', 'noam']:
  1198. step = tf.cast(global_step, dtype=tf.float32)
  1199. warmup_steps = tf.cast(params['warmup_steps'], dtype=tf.float32)
  1200. multiplier = params['hidden_size']**-0.5
  1201. decay = multiplier * tf.minimum((step + 1) * (warmup_steps**-1.5),
  1202. (step + 1)**-0.5)
  1203. return learning_rate * decay
  1204. elif params['learning_rate_decay'] == 'piecewise_constant':
  1205. return tf.compat.v1.train.piecewise_constant(
  1206. tf.cast(global_step, dtype=tf.int32),
  1207. params['learning_rate_boundaries'], params['learning_rate_values'])
  1208. elif params['learning_rate_decay'] == 'none':
  1209. return learning_rate
  1210. else:
  1211. raise ValueError('Unknown learning_rate_decay')
  1212. def average_gradients(tower_grads):
  1213. average_grads = []
  1214. for grad_and_vars in zip(*tower_grads):
  1215. grads = []
  1216. for g, _ in grad_and_vars:
  1217. expanded_g = tf.expand_dims(g, 0)
  1218. grads.append(expanded_g)
  1219. grad = tf.concat(axis=0, values=grads)
  1220. grad = tf.reduce_mean(grad, 0)
  1221. v = grad_and_vars[0][1]
  1222. grad_and_var = (grad, v)
  1223. average_grads.append(grad_and_var)
  1224. return average_grads
  1225. _ENGINE = None
  1226. def all_reduce(tensor):
  1227. if _ENGINE is None:
  1228. return tensor
  1229. return _ENGINE.allreduce(tensor, compression=_ENGINE.Compression.fp16)
  1230. class MultiStepOptimizer(tf.compat.v1.train.Optimizer):
  1231. def __init__(self,
  1232. optimizer,
  1233. step=1,
  1234. use_locking=False,
  1235. name='MultiStepOptimizer'):
  1236. super(MultiStepOptimizer, self).__init__(use_locking, name)
  1237. self._optimizer = optimizer
  1238. self._step = step
  1239. self._step_t = tf.convert_to_tensor(step, name='step')
  1240. def _all_reduce(self, tensor):
  1241. with tf.name_scope(self._name + '_Allreduce'):
  1242. if tensor is None:
  1243. return tensor
  1244. if isinstance(tensor, tf.IndexedSlices):
  1245. tensor = tf.convert_to_tensor(tensor)
  1246. return all_reduce(tensor)
  1247. def compute_gradients(self,
  1248. loss,
  1249. var_list=None,
  1250. gate_gradients=tf.compat.v1.train.Optimizer.GATE_OP,
  1251. aggregation_method=None,
  1252. colocate_gradients_with_ops=False,
  1253. grad_loss=None):
  1254. grads_and_vars = self._optimizer.compute_gradients(
  1255. loss, var_list, gate_gradients, aggregation_method,
  1256. colocate_gradients_with_ops, grad_loss)
  1257. grads, var_list = list(zip(*grads_and_vars))
  1258. # Do not create extra variables when step is 1
  1259. if self._step == 1:
  1260. grads = [self._all_reduce(t) for t in grads]
  1261. return list(zip(grads, var_list))
  1262. first_var = min(var_list, key=lambda x: x.name)
  1263. iter_var = self._create_non_slot_variable(
  1264. initial_value=0 if self._step == 1 else 1,
  1265. name='iter',
  1266. colocate_with=first_var)
  1267. new_grads = []
  1268. for grad, var in zip(grads, var_list):
  1269. grad_acc = self._zeros_slot(var, 'grad_acc', self._name)
  1270. if isinstance(grad, tf.IndexedSlices):
  1271. grad_acc = tf.scatter_add(
  1272. grad_acc,
  1273. grad.indices,
  1274. grad.values,
  1275. use_locking=self._use_locking)
  1276. else:
  1277. grad_acc = tf.assign_add(
  1278. grad_acc, grad, use_locking=self._use_locking)
  1279. def _acc_grad():
  1280. return grad_acc
  1281. def _avg_grad():
  1282. return self._all_reduce(grad_acc / self._step)
  1283. grad = tf.cond(tf.equal(iter_var, 0), _avg_grad, _acc_grad)
  1284. new_grads.append(grad)
  1285. return list(zip(new_grads, var_list))
  1286. def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  1287. if self._step == 1:
  1288. return self._optimizer.apply_gradients(
  1289. grads_and_vars, global_step, name=name)
  1290. grads, var_list = list(zip(*grads_and_vars))
  1291. def _pass_gradients():
  1292. return tf.group(*grads)
  1293. def _apply_gradients():
  1294. op = self._optimizer.apply_gradients(
  1295. zip(grads, var_list), global_step, name)
  1296. with tf.control_dependencies([op]):
  1297. zero_ops = []
  1298. for var in var_list:
  1299. grad_acc = self.get_slot(var, 'grad_acc')
  1300. zero_ops.append(
  1301. grad_acc.assign(
  1302. tf.zeros_like(grad_acc),
  1303. use_locking=self._use_locking))
  1304. zero_op = tf.group(*zero_ops)
  1305. return tf.group(*[op, zero_op])
  1306. iter_var = self._get_non_slot_variable('iter', tf.get_default_graph())
  1307. update_op = tf.cond(
  1308. tf.equal(iter_var, 0), _apply_gradients, _pass_gradients)
  1309. with tf.control_dependencies([update_op]):
  1310. iter_op = iter_var.assign(
  1311. tf.mod(iter_var + 1, self._step_t),
  1312. use_locking=self._use_locking)
  1313. return tf.group(*[update_op, iter_op])
  1314. def shard_features(x, num_datashards):
  1315. x = tf.convert_to_tensor(x)
  1316. batch_size = tf.shape(x)[0]
  1317. size_splits = []
  1318. with tf.device('/cpu:0'):
  1319. for i in range(num_datashards):
  1320. size_splits.append(
  1321. tf.cond(
  1322. tf.greater(
  1323. tf.compat.v1.mod(batch_size, num_datashards),
  1324. i), lambda: batch_size // num_datashards + 1,
  1325. lambda: batch_size // num_datashards))
  1326. return tf.split(x, size_splits, axis=0)