generator.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import torch
  3. class TextGenerator(object):
  4. def __init__(self,
  5. model,
  6. vocab,
  7. symbols,
  8. global_scorer=None,
  9. logger=None,
  10. dump_beam=''):
  11. self.alpha = 0.6
  12. self.logger = logger
  13. self.cuda = (torch.cuda.device_count() > 0)
  14. self.model = model
  15. # TODO generator
  16. self.vocab = vocab
  17. self.symbols = symbols
  18. self.start_token = 101 # ['[PAD]']
  19. self.end_token = 102 # '[PAD]']
  20. self.global_scorer = global_scorer
  21. self.beam_size = 5
  22. self.min_length = 5
  23. self.max_length = 384
  24. self.dump_beam = dump_beam
  25. # for debugging
  26. self.beam_trace = self.dump_beam != ''
  27. self.beam_accum = None
  28. if self.beam_trace:
  29. self.beam_accum = {
  30. 'predicted_ids': [],
  31. 'beam_parent_ids': [],
  32. 'scores': [],
  33. 'log_probs': []
  34. }
  35. def _build_target_tokens(self, pred):
  36. tokens = []
  37. for tok in pred:
  38. tok = int(tok)
  39. tokens.append(tok)
  40. if tokens[-1] == self.end_token:
  41. tokens = tokens[:-1]
  42. break
  43. tokens = [t for t in tokens if t < len(self.vocab)]
  44. tokens = self.vocab.DecodeIds(tokens).split(' ')
  45. return tokens
  46. def tile(self, x, count, dim=0):
  47. """
  48. Tiles x on dimension dim count times.
  49. """
  50. perm = list(range(len(x.size())))
  51. if dim != 0:
  52. perm[0], perm[dim] = perm[dim], perm[0]
  53. x = x.permute(perm).contiguous()
  54. out_size = list(x.size())
  55. out_size[0] *= count
  56. batch = x.size(0)
  57. x = x.view(batch, -1) \
  58. .transpose(0, 1) \
  59. .repeat(count, 1) \
  60. .transpose(0, 1) \
  61. .contiguous() \
  62. .view(*out_size)
  63. if dim != 0:
  64. x = x.permute(perm).contiguous()
  65. return x
  66. def translate_batch(self, encoder_inputs, fast=False):
  67. with torch.no_grad():
  68. return self._fast_translate_batch(
  69. encoder_inputs, self.max_length, min_length=self.min_length)
  70. def _fast_translate_batch(self, encoder_inputs, max_length, min_length=0):
  71. assert not self.dump_beam
  72. beam_size = self.beam_size
  73. tokens, types, padding_mask = encoder_inputs
  74. batch_size = tokens.size(0)
  75. device = tokens.device
  76. tmp_alive_seq = torch.full([batch_size, 1],
  77. self.start_token,
  78. dtype=torch.long,
  79. device=device)
  80. prediction_scores, dec_feat_seq, sequence_output = self.model(
  81. tokens,
  82. types,
  83. padding_mask,
  84. tmp_alive_seq,
  85. None,
  86. None,
  87. checkpoint_activations=False,
  88. is_infer=True,
  89. parallel_output=False,
  90. sequence_output=None)
  91. src_features = sequence_output
  92. src_features = self.tile(src_features, beam_size, dim=0)
  93. attention_mask = self.tile(padding_mask, beam_size, dim=0)
  94. batch_offset = torch.arange(
  95. batch_size, dtype=torch.long, device=device)
  96. beam_offset = torch.arange(
  97. 0,
  98. batch_size * beam_size,
  99. step=beam_size,
  100. dtype=torch.long,
  101. device=device)
  102. alive_seq = torch.full([batch_size * beam_size, 1],
  103. self.start_token,
  104. dtype=torch.long,
  105. device=device)
  106. # Give full probability to the first beam on the first step.
  107. topk_log_probs = (
  108. torch.tensor(
  109. [0.0] + [float('-inf')] * (beam_size - 1),
  110. device=device).repeat(batch_size))
  111. # Structure that holds finished hypotheses.
  112. hypotheses = [[] for _ in range(batch_size)] # noqa: F812
  113. results = {}
  114. results['predictions'] = [[] for _ in range(batch_size)] # noqa: F812
  115. results['scores'] = [[] for _ in range(batch_size)] # noqa: F812
  116. results['gold_score'] = [0] * batch_size
  117. results['batch'] = []
  118. dec_attn_mask = None
  119. dec_position_ids = None
  120. for step in range(max_length):
  121. prediction_scores, dec_feat_seq, _ = self.model(
  122. tokens,
  123. types,
  124. attention_mask,
  125. alive_seq,
  126. dec_position_ids,
  127. dec_attn_mask,
  128. checkpoint_activations=False,
  129. is_infer=True,
  130. parallel_output=False,
  131. sequence_output=src_features)
  132. dec_feat_seq = dec_feat_seq[:, -1, :]
  133. vocab_size = dec_feat_seq.size(-1)
  134. log_probs = torch.log(
  135. torch.softmax(dec_feat_seq.view(-1, vocab_size), dim=-1))
  136. if step < min_length:
  137. log_probs[:, self.end_token] = -1e20
  138. log_probs += topk_log_probs.view(-1).unsqueeze(1)
  139. alpha = self.alpha # global_scorer.alpha
  140. length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha
  141. curr_scores = log_probs / length_penalty
  142. curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
  143. topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)
  144. topk_log_probs = topk_scores * length_penalty
  145. # Resolve beam origin and true word ids.
  146. topk_beam_index = topk_ids.div(vocab_size, rounding_mode='trunc')
  147. topk_ids = topk_ids.fmod(vocab_size)
  148. # Map beam_index to batch_index in the flat representation.
  149. batch_index = (
  150. topk_beam_index
  151. + beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
  152. select_indices = batch_index.view(-1)
  153. # Append last prediction.
  154. alive_seq = torch.cat([
  155. alive_seq.index_select(0, select_indices),
  156. topk_ids.view(-1, 1)
  157. ], -1)
  158. is_finished = topk_ids.eq(self.end_token)
  159. if step + 1 == max_length:
  160. is_finished.fill_(1) # self.end_token)
  161. # End condition is top beam is finished.
  162. end_condition = is_finished[:, 0].eq(1) # self.end_token)
  163. # Save finished hypotheses.
  164. if is_finished.any():
  165. predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
  166. for i in range(is_finished.size(0)):
  167. b = batch_offset[i]
  168. if end_condition[i]:
  169. is_finished[i].fill_(1) # self.end_token)
  170. finished_hyp = is_finished[i].nonzero().view(-1)
  171. # Store finished hypotheses for this batch.
  172. for j in finished_hyp:
  173. hypotheses[b].append(
  174. (topk_scores[i, j], predictions[i, j, 1:]))
  175. # If the batch reached the end, save the n_best hypotheses.
  176. if end_condition[i]:
  177. best_hyp = sorted(
  178. hypotheses[b], key=lambda x: x[0], reverse=True)
  179. score, pred = best_hyp[0]
  180. results['scores'][b].append(score)
  181. results['predictions'][b].append(pred)
  182. non_finished = end_condition.eq(0).nonzero().view(-1)
  183. # If all sentences are translated, no need to go further.
  184. if len(non_finished) == 0:
  185. break
  186. # Remove finished batches for the next step.
  187. topk_log_probs = topk_log_probs.index_select(0, non_finished)
  188. batch_index = batch_index.index_select(0, non_finished)
  189. batch_offset = batch_offset.index_select(0, non_finished)
  190. alive_seq = predictions.index_select(0, non_finished) \
  191. .view(-1, alive_seq.size(-1))
  192. # Reorder states.
  193. select_indices = batch_index.view(-1)
  194. src_features = src_features.index_select(0, select_indices)
  195. attention_mask = attention_mask.index_select(0, select_indices)
  196. return results