distributed_plug.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Dict
  3. import torch
  4. from megatron_util import mpu, print_rank_0
  5. from megatron_util.fp16 import FP16_Module
  6. from torch.nn import functional as F
  7. from modelscope.models import TorchModel
  8. from modelscope.models.base import Tensor
  9. from modelscope.utils.logger import get_logger
  10. from modelscope.utils.megatron_utils import init_megatron_util
  11. from modelscope.utils.nlp.load_checkpoint import pre_load
  12. from . import PlugModel
  13. from .configuration import PlugNLGConfig
  14. logger = get_logger()
  class DistributedPlug(TorchModel):
      """
      The wrapper class of PLUG Model to initialize parallel environment, load model weights, generate sentences.

      Parameters:
          model_dir (`str`, *required*):
              Path to model damo/nlp_plug_text-generation_27B.
              The model structure in model_dir should be like this:
              model_dir
                  |_ config.json
                  |_ configuration.json
                  |_ ds_zero-offload_10B_config.json
                  |_ vocab.txt
                  |_ model <-- an empty directory
              Model binaries shall be downloaded separately to populate the model directory, so that
              the model directory would contain the following binaries:
                  |_ model
                      |_ mp_rank_00_model_states.pt
                      |_ mp_rank_01_model_states.pt
                      |_ mp_rank_02_model_states.pt
                      |_ mp_rank_03_model_states.pt
                      |_ mp_rank_04_model_states.pt
                      |_ mp_rank_05_model_states.pt
                      |_ mp_rank_06_model_states.pt
                      |_ mp_rank_07_model_states.pt
          rank (`int`, *required*):
              Used to identify different GPUs in a tensor parallel environment. eg. The rank of GPU #0 is 0, and the
              model file `mp_rank_00_model_states.pt` will be loaded on this GPU.
          world_size (`int`, *required*, defaults to 8):
              The parallel size in total.
          model_parallel_size (`int`, *required*, defaults to 8):
              The parallel size of model(tensor parallel).
          master_ip (`str`, *required*):
              The master IP, can usually be set to `"127.0.0.1"`, used as part of
              [`~torch.distributed.init_process_group`] method parameter `init_method`.
              `init_method` = `"tcp://{master_ip}:{master_port}"`
          master_port (`str`, *required*):
              The master port, can usually be set to `"29500"`, used as part of
              [`~torch.distributed.init_process_group`] method parameter `init_method`.
              `init_method` = `"tcp://{master_ip}:{master_port}"`
          seed (`int`, *optional*, defaults to 42):
              Random seed to control sampling.

      The keyword arguments are also kept as `self.model_cfg`; `generate` reads
      the `temperature`, `top_k` and `top_p` entries from it.
      """

      # Index of the [SEP] token in BertTokenizer.
      _SEP_TOKEN_ID = 102
      # Id substituted when a sampled id falls outside the original vocabulary.
      # NOTE(review): presumably [UNK] in the BERT vocabulary -- confirm.
      _FALLBACK_TOKEN_ID = 100
      # Number of decoding steps between two sliding-window refreshes.
      _WINDOW_STRIDE = 128
      # The encoder input is truncated to this many trailing tokens when the
      # refreshed context would overflow.
      _MAX_INPUT_TOKENS = 512

      def __init__(self, model_dir, rank, **kwargs):
          """Set up the parallel environment and build the model.

          Args:
              model_dir: Model directory, forwarded to the base class, the
                  config loader and `init_megatron_util`.
              rank: Rank of the current process, forwarded to
                  `init_megatron_util`.
              **kwargs: Forwarded to the base class and stored unchanged as
                  `self.model_cfg`.
          """
          super().__init__(model_dir, **kwargs)
          self.rank = rank
          self.model_cfg = kwargs
          self.config = PlugNLGConfig.from_pretrained(model_dir)
          init_megatron_util(model_dir=model_dir, rank=rank)
          self.iteration = 0
          self.model = self.initialize_model(path_load_tag='model')

      def initialize_model(self, path_load_tag='model'):
          """Build the model.

          Creates a `PlugModel` from `self.config`, moves it to the current
          CUDA device, applies the fp16 / fp32 options of the config and loads
          the checkpoint shard returned by `pre_load`.

          Args:
              path_load_tag: Passed to `pre_load` as `tag`.

          Returns:
              The constructed model with weights loaded.
          """
          print_rank_0('Building Plug model. It will take a few minutes ...')
          model = PlugModel(self.config)

          if mpu.get_data_parallel_rank() == 0:
              logger.info(
                  ' > number of parameters on model parallel rank {}: {}'.format(
                      mpu.get_tensor_model_parallel_rank(),
                      sum(p.nelement() for p in model.parameters())))

          if self.config.deepspeed and self.config.fp16:
              model.half()

          # GPU allocation.
          model.cuda(torch.cuda.current_device())

          # Fp16 conversion.
          if self.config.fp16:
              model = FP16_Module(model)
              embeddings = model.module.model.bert.embeddings
              if self.config.fp32_embedding:
                  embeddings.word_embeddings.float()
                  embeddings.position_embeddings.float()
                  embeddings.token_type_embeddings.float()
              if self.config.fp32_tokentypes:
                  embeddings.token_type_embeddings.float()
              if self.config.fp32_layernorm:
                  for name, _module in model.named_modules():
                      if 'LayerNorm' in name:
                          _module.float()

          load_model = pre_load(
              mpu.get_tensor_model_parallel_rank(),
              self.model_dir,
              tag=path_load_tag)
          # NOTE(review): `model.module` is only known to exist from this file
          # when the model was wrapped by FP16_Module above -- confirm that
          # fp16 is always enabled for this model.
          inner_model = model.module.model
          model_dict = inner_model.state_dict()
          for key in load_model:
              if key in model_dict:
                  print_rank_0('Loading key: ' + key)
              else:
                  print_rank_0('Skip key: ' + key)
          # strict=False: keys reported as skipped above are ignored.
          inner_model.load_state_dict(load_model, strict=False)
          return model

      @staticmethod
      def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
          """Filter logits with top-k and/or nucleus (top-p) filtering.

          Args:
              logits: Logits tensor. The top-p branch reshapes it with
                  `logits.view(logits.size(1))`, so that branch only works when
                  the first dimension is 1.
              top_k: When > 0, keep only the `top_k` largest logits along the
                  last dimension. Values larger than the last dimension are
                  clamped to it.
              top_p: When > 0.0, keep the smallest set of tokens whose
                  cumulative probability exceeds `top_p`.
              filter_value: Value written to the removed positions.

          Returns:
              The filtered logits. Filtered positions are written in place, so
              the tensor passed in is modified as well.
          """
          # This function has been mostly taken from huggingface conversational ai code at
          # https://medium.com/huggingface/how-to-build-a-state-of-the-art-
          # conversational-ai-with-transfer-learning-2d818ac26313
          if top_k > 0:
              # torch.topk raises when k exceeds the size of the dimension.
              top_k = min(top_k, logits.size(-1))
              # Remove all tokens with a probability less than the last token of the top-k
              kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
              logits[logits < kth_largest] = filter_value

          if top_p > 0.0:
              # convert to 1D
              logits = logits.view(logits.size(1)).contiguous()
              sorted_logits, sorted_indices = torch.sort(logits, descending=True)
              cumulative_probs = torch.cumsum(
                  F.softmax(sorted_logits, dim=-1), dim=-1)

              # Remove tokens with cumulative probability above the threshold
              sorted_indices_to_remove = cumulative_probs > top_p
              # Shift the indices to the right to keep also the first token above the threshold
              sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                  ..., :-1].clone()
              sorted_indices_to_remove[..., 0] = 0
              indices_to_remove = sorted_indices[sorted_indices_to_remove]
              logits[indices_to_remove] = filter_value
              # going back to 2D
              logits = logits.view(1, -1).contiguous()
          return logits

      def forward(self,
                  input_tokens,
                  token_type_ids=None,
                  attention_mask=None,
                  target_tokens=None,
                  position_ids=None,
                  decode_attention_mask=None,
                  checkpoint_activations=False,
                  is_infer=False,
                  sequence_output=None,
                  parallel_output=True):
          """Delegate to the wrapped model; all arguments are passed through."""
          return self.model(
              input_tokens,
              token_type_ids,
              attention_mask,
              target_tokens,
              position_ids,
              decode_attention_mask,
              checkpoint_activations=checkpoint_activations,
              is_infer=is_infer,
              sequence_output=sequence_output,
              parallel_output=parallel_output)

      def generate(self, input: Dict[str, Tensor], out_length=128, *args,
                   **kwargs):
          """Sample up to `out_length` tokens. Only supports batch_size=1.

          Args:
              input: Dict holding the `input_ids`, `dec_input_ids` and
                  `attention_mask` tensors.
              out_length: Maximum number of decoding steps.
              *args: Ignored; kept for backward compatibility.
              **kwargs: Ignored; accepted so that extra keyword arguments do
                  not raise.

          Returns:
              `{'generate_context': ids}` where `ids` is the list of sampled
              token ids with consecutive repetitions of the fallback id
              collapsed into one.

          Raises:
              KeyError: If `temperature`, `top_k` or `top_p` was not passed to
                  the constructor.
          """
          sep_token_idx = self._SEP_TOKEN_ID
          fallback_idx = self._FALLBACK_TOKEN_ID

          device = torch.cuda.current_device()
          batch_size = input['input_ids'].shape[0]
          tokens = input['input_ids'].view(1, -1).contiguous().to(device)
          dec_input_ids = input['dec_input_ids'].to(device)
          attention_mask = input['attention_mask'].to(device)
          self.model.eval()
          with torch.no_grad():
              all_generate_tokens = []
              generate_tokens = []
              counter = 0
              sequence_output = None
              vocab_size = self.config.original_vocab_size
              while counter < out_length:
                  if counter % self._WINDOW_STRIDE == 0 and counter != 0:
                      # Sliding window: write the tokens generated since the
                      # last refresh (terminated by [SEP]) into the encoder
                      # input and restart the decoder.
                      generate_tokens.append(sep_token_idx)
                      window = torch.tensor(
                          generate_tokens,
                          dtype=torch.long,
                          device=tokens.device)
                      # `.item()` raises unless exactly one [SEP] is present;
                      # the original tensor-valued comparison failed on the
                      # same inputs.
                      start = (tokens == sep_token_idx).nonzero(
                          as_tuple=True)[-1].item()
                      if start + len(generate_tokens) >= self._MAX_INPUT_TOKENS:
                          # Slice the single row (not the batch axis) so the
                          # concatenation is 1-D, keep the trailing tokens and
                          # restore the (1, seq) shape the model is called with.
                          tokens = torch.cat(
                              [tokens[0, :start], window],
                              -1)[-self._MAX_INPUT_TOKENS:].view(1, -1)
                      else:
                          # NOTE(review): assumes the input row is long enough
                          # to hold the window -- confirm inputs are padded.
                          tokens[0, start:start + len(generate_tokens)] = window

                      attention_mask = (tokens != 0)
                      dec_input_ids = input['dec_input_ids'].to(device)
                      generate_tokens = []
                      sequence_output = None

                  position_ids = torch.full([batch_size, 1],
                                            len(generate_tokens),
                                            dtype=torch.long,
                                            device=device)
                  # NOTE(review): the 5th/6th positional arguments are passed
                  # as (attention_mask, position_ids), whereas `forward` above
                  # orders them (position_ids, decode_attention_mask) --
                  # confirm against the wrapped model's signature.
                  _, logits, sequence_output = self.model(
                      tokens,
                      None,
                      attention_mask,
                      dec_input_ids,
                      attention_mask,
                      position_ids,
                      is_infer=True,
                      sequence_output=sequence_output,
                      parallel_output=False)
                  # Only the logits of the last position are sampled from.
                  logits = logits[:, -1, :]
                  logits = logits / self.model_cfg['temperature']
                  logits = self.top_k_logits(
                      logits,
                      top_k=self.model_cfg['top_k'],
                      top_p=self.model_cfg['top_p'])
                  probs = F.softmax(logits, dim=-1)
                  prev = torch.multinomial(probs, num_samples=1)
                  prev_token = prev[0].item()
                  if prev_token >= vocab_size:
                      # Sampled id is outside the original vocabulary.
                      prev_token = fallback_idx
                      prev[0] = fallback_idx
                  if prev_token == sep_token_idx:
                      # Stop at [SEP] once more than 80% of the requested
                      # length has been produced; otherwise drop the [SEP]
                      # and keep sampling.
                      if len(all_generate_tokens) > int(
                              max(1, out_length) * 0.8):
                          break
                      counter += 1
                      continue
                  dec_input_ids = torch.cat([dec_input_ids, prev], dim=1)
                  generate_tokens.append(prev_token)
                  all_generate_tokens.append(prev_token)
                  counter += 1

          # Collapse runs of the fallback id into a single occurrence.
          generate_context = []
          for token in all_generate_tokens:
              if (generate_context and generate_context[-1] == fallback_idx
                      and token == fallback_idx):
                  continue
              generate_context.append(token)
          return {'generate_context': generate_context}
  229. else:
  230. generate_context.append(token)
  231. return {'generate_context': generate_context}