inference.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. # Copyright (c) 2022 Zhipu.AI
  2. from typing import List
  3. import torch
  4. import torch.nn.functional as F
  5. def get_ltor_masks_and_position_ids(
  6. data,
  7. eod_token,
  8. reset_position_ids,
  9. reset_attention_mask,
  10. ):
  11. """Build masks and position id for left to right model."""
  12. # Extract batch size and sequence length.
  13. micro_batch_size, seq_length = data.size()
  14. # Attention mask (lower triangular).
  15. if reset_attention_mask:
  16. att_mask_batch = micro_batch_size
  17. else:
  18. att_mask_batch = 1
  19. attention_mask = torch.tril(
  20. torch.ones((att_mask_batch, seq_length, seq_length),
  21. device=data.device)).view(att_mask_batch, 1, seq_length,
  22. seq_length)
  23. # Position ids.
  24. position_ids = torch.arange(
  25. seq_length, dtype=torch.long, device=data.device)
  26. position_ids = position_ids.unsqueeze(0).expand_as(data)
  27. # We need to clone as the ids will be modified based on batch index.
  28. if reset_position_ids:
  29. position_ids = position_ids.clone()
  30. if reset_position_ids or reset_attention_mask:
  31. # Loop through the batches:
  32. for b in range(micro_batch_size):
  33. # Find indices where EOD token is.
  34. eod_index = position_ids[b, data[b] == eod_token]
  35. # Detach indices from positions if going to modify positions.
  36. if reset_position_ids:
  37. eod_index = eod_index.clone()
  38. # Loop through EOD indices:
  39. prev_index = 0
  40. for j in range(eod_index.size()[0]):
  41. i = eod_index[j]
  42. # Mask attention loss.
  43. if reset_attention_mask:
  44. attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
  45. # Reset positions.
  46. if reset_position_ids:
  47. position_ids[b, (i + 1):] -= i + 1 - prev_index
  48. prev_index = i + 1
  49. # Convert attention mask to binary:
  50. attention_mask = attention_mask < 0.5
  51. return attention_mask, position_ids
  52. def get_batch(
  53. context_tokens,
  54. micro_batch_size,
  55. eod_token,
  56. reset_position_ids=False,
  57. reset_attention_mask=False,
  58. ):
  59. """Generate batch from context tokens."""
  60. tokens = context_tokens.view(micro_batch_size, -1).contiguous().cuda()
  61. # Get the attention mask and position ids.
  62. attention_mask, position_ids = get_ltor_masks_and_position_ids(
  63. tokens,
  64. eod_token,
  65. reset_position_ids,
  66. reset_attention_mask,
  67. )
  68. return tokens, attention_mask, position_ids
  69. def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
  70. """This function has been mostly taken from huggingface conversational
  71. ai code at
  72. https://medium.com/huggingface/how-to-build-a-state-of-the-art-
  73. conversational-ai-with-transfer-learning-2d818ac26313"""
  74. if top_k > 0:
  75. # Remove all tokens with a probability less than the
  76. # last token of the top-k
  77. indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1,
  78. None]
  79. logits[indices_to_remove] = filter_value
  80. if top_p > 0.0:
  81. # Cconvert to 1D
  82. sorted_logits, sorted_indices = torch.sort(
  83. logits, descending=True, dim=-1)
  84. cumulative_probs = torch.cumsum(
  85. F.softmax(sorted_logits, dim=-1), dim=-1)
  86. # Remove tokens with cumulative probability above the threshold
  87. sorted_indices_to_remove = cumulative_probs > top_p
  88. # Shift the indices to the right to keep also the first token
  89. # above the threshold
  90. sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
  91. ..., :-1].clone()
  92. sorted_indices_to_remove[..., 0] = 0
  93. for i in range(sorted_indices.size(0)):
  94. indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
  95. logits[i][indices_to_remove] = filter_value
  96. return logits
  97. def pad_batch(batch, pad_id, seq_length):
  98. context_lengths = []
  99. for tokens in batch:
  100. context_length = len(tokens)
  101. if context_length < seq_length:
  102. tokens.extend([pad_id] * (seq_length - context_length))
  103. context_lengths.append(context_length)
  104. return batch, context_lengths
  105. def get_token_stream(
  106. model,
  107. tokenizer,
  108. seq_length,
  109. out_seq_length,
  110. context_tokens,
  111. return_scores: bool = False,
  112. prompt_length: int = None,
  113. micro_batch_size: int = None,
  114. bad_ids: List = None,
  115. temperature: float = 1.0,
  116. topp: float = 1.0,
  117. topk: int = 0.0,
  118. greedy: bool = False,
  119. ):
  120. context_tokens, context_lengths = pad_batch(context_tokens,
  121. tokenizer.eos_token_id,
  122. seq_length)
  123. context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
  124. context_length_tensor = torch.cuda.LongTensor(context_lengths)
  125. context_length = context_length_tensor.min().item()
  126. tokens, attention_mask, position_ids = get_batch(
  127. context_tokens_tensor,
  128. micro_batch_size,
  129. tokenizer.eos_token_id,
  130. )
  131. batch_token_iterator = sample_sequence_batch(
  132. model,
  133. tokenizer,
  134. context_tokens_tensor,
  135. context_length_tensor,
  136. attention_mask,
  137. position_ids,
  138. seq_length=seq_length,
  139. out_seq_length=out_seq_length,
  140. return_scores=return_scores,
  141. prompt_length=prompt_length,
  142. bad_ids=bad_ids,
  143. temperature=temperature,
  144. topp=topp,
  145. topk=topk,
  146. greedy=greedy,
  147. )
  148. for tokens, lengths in batch_token_iterator:
  149. context_length += 1
  150. if tokens is not None:
  151. yield tokens[:, :context_length], lengths
  152. else:
  153. yield None, None
  154. def switch(val1, val2, boolean):
  155. boolean = boolean.type_as(val1)
  156. return (1 - boolean) * val1 + boolean * val2
  157. def sample_sequence_batch(
  158. model,
  159. tokenizer,
  160. context_tokens,
  161. context_lengths,
  162. attention_mask,
  163. position_ids,
  164. seq_length,
  165. out_seq_length,
  166. maxlen=None,
  167. return_scores: bool = False,
  168. prompt_length: int = None,
  169. bad_ids: List = None,
  170. temperature: float = 1.0,
  171. topp: float = 1.0,
  172. topk: int = 0.0,
  173. recompute: bool = False,
  174. greedy: bool = False,
  175. ):
  176. model.eval()
  177. with torch.no_grad():
  178. context_length = context_lengths.min().item()
  179. eos_id = tokenizer.eos_token_id
  180. counter = 0
  181. org_context_length = context_length
  182. layer_past = None
  183. batch_size = context_tokens.size(0)
  184. is_done = torch.zeros([batch_size]).byte().cuda()
  185. tokens = context_tokens
  186. if maxlen is None:
  187. maxlen = seq_length - 1
  188. if maxlen > (org_context_length + out_seq_length):
  189. maxlen = org_context_length + out_seq_length
  190. lengths = torch.ones([batch_size]).long().cuda() * maxlen
  191. if return_scores:
  192. scores = torch.zeros([batch_size]).float().cuda()
  193. while context_length <= (maxlen):
  194. if recompute:
  195. logits = model(
  196. tokens,
  197. position_ids,
  198. attention_mask,
  199. prompt_length=prompt_length,
  200. context_length=context_length,
  201. )
  202. logits = logits[:, context_length - 1, :]
  203. else:
  204. if counter == 0:
  205. tokens2use = tokens[:, :context_length]
  206. positions2use = position_ids[:, :context_length]
  207. else:
  208. tokens2use = tokens[:, context_length - 1].view(
  209. batch_size, -1)
  210. positions2use = position_ids[:, context_length - 1].view(
  211. batch_size, -1)
  212. logits, layer_past = model(
  213. tokens2use,
  214. positions2use,
  215. attention_mask,
  216. layer_past=layer_past,
  217. get_key_value=True,
  218. prompt_length=prompt_length,
  219. context_length=context_length,
  220. )
  221. logits = logits[:, -1].view(batch_size, -1).contiguous()
  222. if bad_ids is not None:
  223. for bad_id in bad_ids:
  224. logits[:, bad_id] = -10000
  225. if greedy:
  226. prev = torch.argmax(logits, dim=-1).view(-1)
  227. else:
  228. logits = logits.float()
  229. if return_scores:
  230. orig_log_probs = torch.log_softmax(logits, dim=-1)
  231. logits /= temperature
  232. logits = top_k_logits(logits, top_k=topk, top_p=topp)
  233. log_probs = F.softmax(logits, dim=-1)
  234. prev = torch.multinomial(log_probs, num_samples=1).view(-1)
  235. started = context_lengths <= context_length
  236. new_tokens = switch(tokens[:, context_length].view(-1), prev,
  237. started)
  238. if not greedy and return_scores:
  239. indices = prev.view(-1, 1)
  240. new_scores = orig_log_probs.gather(1, indices).view(-1)
  241. new_scores = new_scores * started
  242. new_scores = new_scores * is_done.bool().logical_not()
  243. scores += new_scores
  244. tokens[:, context_length] = new_tokens
  245. done_token = (prev == eos_id).byte() & started.byte()
  246. just_finished = (done_token & ~is_done).bool()
  247. lengths[just_finished.view(-1)] = context_length
  248. is_done = is_done | done_token
  249. done = torch.all(is_done)
  250. if return_scores:
  251. yield tokens, (lengths, scores)
  252. else:
  253. yield tokens, lengths
  254. context_length += 1
  255. counter += 1
  256. if done:
  257. break