distributed_gpt3.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import math
  16. import os
  17. from collections import OrderedDict
  18. from typing import Callable, Dict, List, Optional, Union
  19. import torch
  20. from megatron_util import get_args, mpu
  21. from megatron_util.global_vars import get_global_memory_buffer
  22. from megatron_util.model import (AttnMaskType, Float16Module, LayerNorm,
  23. bias_gelu_impl)
  24. from megatron_util.model.fused_softmax import FusedScaleMaskSoftmax
  25. from torch import nn
  26. from torch.nn import functional as F
  27. from transformers.modeling_utils import PreTrainedModel
  28. from modelscope.models import TorchModel
  29. from modelscope.models.nlp.gpt3 import GPT3Config
  30. from modelscope.outputs import TextGenerationModelOutput, TokenGeneratorOutput
  31. from modelscope.utils.megatron_utils import init_megatron_util
  32. from modelscope.utils.nlp.load_checkpoint import pre_load
  33. from modelscope.utils.streaming_output import StreamingOutputMixin
  34. class GPT3ParallelMLP(nn.Module):
  35. """MLP.
  36. MLP will take the input with h hidden state, project it to 4*h
  37. hidden dimension, perform nonlinear transformation, and project the
  38. state back into h hidden dimension.
  39. """
  40. def __init__(self, config, init_method, output_layer_init_method):
  41. super().__init__()
  42. # Project to 4h.
  43. self.dense_h_to_4h = mpu.ColumnParallelLinear(
  44. config.hidden_size,
  45. config.ffn_hidden_size,
  46. gather_output=False,
  47. init_method=init_method,
  48. skip_bias_add=True)
  49. self.bias_gelu_fusion = config.bias_gelu_fusion
  50. self.activation_func = F.gelu
  51. # Project back to h.
  52. self.dense_4h_to_h = mpu.RowParallelLinear(
  53. config.ffn_hidden_size,
  54. config.hidden_size,
  55. input_is_parallel=True,
  56. init_method=output_layer_init_method,
  57. skip_bias_add=True)
  58. def forward(self, hidden_states):
  59. # [s, b, 4hp]
  60. intermediate_parallel, bias_parallel = self.dense_h_to_4h(
  61. hidden_states)
  62. if self.bias_gelu_fusion:
  63. intermediate_parallel = \
  64. bias_gelu_impl(intermediate_parallel, bias_parallel)
  65. else:
  66. intermediate_parallel = \
  67. self.activation_func(intermediate_parallel + bias_parallel)
  68. # [s, b, h]
  69. output, output_bias = self.dense_4h_to_h(intermediate_parallel)
  70. return output, output_bias
  71. class GPT3Embedding(nn.Module):
  72. """Language model embeddings.
  73. Arguments:
  74. hidden_size: hidden size
  75. vocab_size: vocabulary size
  76. max_sequence_length: maximum size of sequence. This
  77. is used for positional embedding
  78. embedding_dropout_prob: dropout probability for embeddings
  79. init_method: weight initialization method
  80. num_tokentypes: size of the token-type embeddings. 0 value
  81. will ignore this embedding
  82. """
  83. def __init__(self, config, init_method):
  84. super().__init__()
  85. self.hidden_size = config.hidden_size
  86. self.init_method = init_method
  87. # Word embeddings (parallel).
  88. self.word_embeddings = mpu.VocabParallelEmbedding(
  89. config.vocab_size, self.hidden_size, init_method=self.init_method)
  90. # Position embedding (serial).
  91. self.position_embeddings = nn.Embedding(config.max_position_embeddings,
  92. self.hidden_size)
  93. # Initialize the position embeddings.
  94. self.init_method(self.position_embeddings.weight)
  95. self.fp32_residual_connection = config.fp32_residual_connection
  96. self.sequence_parallel = config.sequence_parallel
  97. # Embeddings dropout
  98. self.embedding_dropout = nn.Dropout(config.hidden_dropout)
  99. def zero_parameters(self):
  100. """Zero out all parameters in embedding."""
  101. self.word_embeddings.weight.data.fill_(0)
  102. self.word_embeddings.weight.shared = True
  103. self.position_embeddings.weight.data.fill_(0)
  104. self.position_embeddings.weight.shared = True
  105. def forward(self, input_ids, position_ids):
  106. # Embeddings.
  107. words_embeddings = self.word_embeddings(input_ids)
  108. position_embeddings = self.position_embeddings(position_ids)
  109. embeddings = words_embeddings + position_embeddings
  110. # Data format change to avoid explicit transposes : [b s h] --> [s b h].
  111. embeddings = embeddings.transpose(0, 1).contiguous()
  112. # If the input flag for fp32 residual connection is set, convert for float.
  113. if self.fp32_residual_connection:
  114. embeddings = embeddings.float()
  115. # Dropout.
  116. if self.sequence_parallel:
  117. embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
  118. with mpu.get_cuda_rng_tracker().fork():
  119. embeddings = self.embedding_dropout(embeddings)
  120. else:
  121. embeddings = self.embedding_dropout(embeddings)
  122. return embeddings
  123. class NoopTransformerLayer(nn.Module):
  124. def __init__(self, layer_number):
  125. super().__init__()
  126. self.layer_number = layer_number
  127. def forward(self,
  128. hidden_states,
  129. attention_mask,
  130. encoder_output=None,
  131. enc_dec_attn_mask=None,
  132. inference_params=None):
  133. return hidden_states.clone()
  134. def attention_mask_func(attention_scores, attention_mask):
  135. attention_scores.masked_fill_(attention_mask, -10000.0)
  136. return attention_scores
class GPT3CoreAttention(nn.Module):
    """Scaled dot-product attention over already-projected Q/K/V tensors.

    Works on this tensor-parallel rank's share of the attention heads.
    Raw scores come from ``torch.baddbmm``. They are then masked and
    normalized by ``FusedScaleMaskSoftmax``, passed through attention
    dropout, and multiplied with the value tensor.

    Args:
        config: model config; reads ``fp16``, ``bf16``,
            ``apply_query_key_layer_scaling``, ``attention_softmax_in_fp32``,
            ``sequence_parallel``, ``kv_channels``, ``num_attention_heads``,
            ``masked_softmax_fusion`` and ``attention_dropout``.
        layer_number: index of the owning layer, clamped to at least 1. When
            query-key layer scaling is enabled it is used as an extra
            scaling coefficient.
        attn_mask_type: mask type forwarded to ``FusedScaleMaskSoftmax``.
    """

    def __init__(self,
                 config,
                 layer_number,
                 attn_mask_type=AttnMaskType.padding):
        super().__init__()
        self.fp16 = config.fp16
        self.bf16 = config.bf16
        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        # Query-key layer scaling always forces the softmax into fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = config.sequence_parallel
        projection_size = config.kv_channels * config.num_attention_heads
        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, config.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            config.num_attention_heads, world_size)
        # Scores are divided by norm_factor = sqrt(head_dim), times
        # layer_number when layer scaling is on. The same layer_number is
        # passed to FusedScaleMaskSoftmax as `coeff` (presumably to undo the
        # extra division inside the fp32 softmax -- confirm against
        # megatron_util).
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16, self.attn_mask_type,
            config.masked_softmax_fusion, attention_mask_func,
            self.attention_softmax_in_fp32, coeff)
        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        """Compute attention for this partition's heads.

        Args:
            query_layer: ``[sq, b, np, hn]``.
            key_layer: ``[sk, b, np, hn]``.
            value_layer: ``[sk, b, np, hn]``.
            attention_mask: forwarded to ``self.scale_mask_softmax``.
                ``attention_mask_func`` fills positions where the mask is
                True with -10000.0; whether the fused kernel path uses that
                same convention is not visible here.

        Returns:
            Context tensor ``[sq, b, hp]``, where ``hp`` is
            ``hidden_size_per_partition``.
        """
        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================
        # [b, np, sq, sk]
        output_size = (query_layer.size(1), query_layer.size(2),
                       query_layer.size(0), key_layer.size(0))
        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)
        # Preallocated input tensor: [b * np, sq, sk]. It is passed as
        # baddbmm's `input` with beta=0.0, and torch.baddbmm ignores
        # `input` when beta is 0, so its (uninitialized) contents never
        # reach the result.
        matmul_input_buffer = get_global_memory_buffer().get_tensor(
            (output_size[0] * output_size[1], output_size[2], output_size[3]),
            query_layer.dtype, 'mpu')
        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor))
        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)
        # ===========================
        # Attention probs and dropout
        # ===========================
        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        # Without sequence parallelism, the dropout is forked onto the
        # model-parallel RNG here. With it, GPT3ParallelTransformer.forward
        # already runs the whole layer stack inside that fork.
        if not self.sequence_parallel:
            with mpu.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)
        # =========================
        # Context layer. [sq, b, hp]
        # =========================
        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        # context layer shape: [b, np, sq, hn]
        # (query_layer is [sq, b * np, hn] here, so size(0) is still sq.)
        output_size = (value_layer.size(1), value_layer.size(2),
                       query_layer.size(0), value_layer.size(3))
        # change view [sk, b * np, hn]
        value_layer = value_layer.view(
            value_layer.size(0), output_size[0] * output_size[1], -1)
        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)
        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
class GPT3ParallelAttention(nn.Module):
    """Tensor-parallel self-attention layer.

    It has three stages:
    - a column-parallel fused QKV projection;
    - ``GPT3CoreAttention``;
    - a row-parallel output projection.

    Input and output are both ``[s, b, h]``. When ``inference_params`` is
    given, keys and values are cached per layer across calls.
    """

    def __init__(self, config, init_method, output_layer_init_method,
                 layer_number):
        super().__init__()
        # Also used as this layer's key in the KV-cache dict.
        self.layer_number = max(1, layer_number)
        self.params_dtype = config.params_dtype
        projection_size = config.kv_channels * config.num_attention_heads
        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, config.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            config.num_attention_heads, world_size)
        # Strided linear layer.
        self.query_key_value = mpu.ColumnParallelLinear(
            config.hidden_size,
            3 * projection_size,
            gather_output=False,
            init_method=init_method)
        self.core_attention = GPT3CoreAttention(config, self.layer_number)
        # Output. The bias is returned separately (skip_bias_add=True).
        self.dense = mpu.RowParallelLinear(
            projection_size,
            config.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized KV-cache buffer.

        Shape is ``[max_seq, batch, np, hn]``, with dtype ``params_dtype``,
        on the current CUDA device.
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask, inference_params=None):
        """Run self-attention.

        Args:
            hidden_states: ``[sq, b, h]``.
            attention_mask: forwarded unchanged to ``GPT3CoreAttention``.
            inference_params: optional KV-cache state. It must provide
                ``key_value_memory_dict``, ``max_sequence_len``,
                ``max_batch_size``, ``batch_size_offset`` and
                ``sequence_len_offset``.

        Returns:
            ``(output, bias)`` from the output projection; the bias is not
            yet added.
        """
        # hidden_states: [sq, b, h]
        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params:
            # First call for this layer allocates its cache; later calls
            # reuse it.
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]
        # =====================
        # Query, Key, and Value
        # =====================
        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer, _ = self.query_key_value(hidden_states)
        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        (query_layer, key_layer,
         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        # ==================================
        # Adjust key and value for inference
        # ==================================
        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            # Attend over the whole cached prefix [0, sequence_end). The
            # cache comes from torch.empty, so this assumes earlier calls
            # already filled positions < sequence_start.
            key_layer = inference_key_memory[:sequence_end,
                                             batch_start:batch_end, ...]
            value_layer = inference_value_memory[:sequence_end,
                                                 batch_start:batch_end, ...]
        # ==================================
        # core attention computation
        # ==================================
        context_layer = self.core_attention(query_layer, key_layer,
                                            value_layer, attention_mask)
        # =================
        # Output. [sq, b, h]
        # =================
        output, bias = self.dense(context_layer)
        return output, bias
  337. class nullcontext:
  338. def __init__(self, enter_result=None):
  339. self.enter_result = enter_result
  340. def __enter__(self):
  341. return self.enter_result
  342. def __exit__(self, *excinfo):
  343. pass
  344. def bias_dropout_add(x, bias, residual, prob, training):
  345. # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
  346. out = F.dropout(x + bias, p=prob, training=training)
  347. out = residual + out
  348. return out
  349. def get_bias_dropout_add(training):
  350. def _bias_dropout_add(x, bias, residual, prob):
  351. return bias_dropout_add(x, bias, residual, prob, training)
  352. return _bias_dropout_add
@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled ``residual + dropout(x + bias)`` with dropout on.

    GPT3ParallelTransformerLayer.forward selects this when
    ``bias_dropout_fusion`` is enabled and the module is in training mode.
    """
    return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled ``residual + (x + bias)`` with dropout disabled.

    GPT3ParallelTransformerLayer.forward selects this when
    ``bias_dropout_fusion`` is enabled and the module is not training.
    """
    return bias_dropout_add(x, bias, residual, prob, False)
  363. class GPT3ParallelTransformerLayer(nn.Module):
  364. """A single transformer layer.
  365. Transformer layer takes input with size [s, b, h] and returns an
  366. output of the same size.
  367. """
  368. def __init__(self, config, init_method, output_layer_init_method,
  369. layer_number):
  370. super().__init__()
  371. self.layer_number = layer_number
  372. self.apply_residual_connection_post_layernorm \
  373. = config.apply_residual_connection_post_layernorm
  374. self.bf16 = config.bf16
  375. self.fp32_residual_connection = config.fp32_residual_connection
  376. # Layernorm on the input data.
  377. self.input_layernorm = LayerNorm(
  378. config.hidden_size,
  379. eps=config.layernorm_epsilon,
  380. no_persist_layer_norm=config.no_persist_layer_norm,
  381. sequence_parallel=config.sequence_parallel)
  382. # Self attention.
  383. self.self_attention = GPT3ParallelAttention(config, init_method,
  384. output_layer_init_method,
  385. layer_number)
  386. self.hidden_dropout = config.hidden_dropout
  387. self.bias_dropout_fusion = config.bias_dropout_fusion
  388. # Layernorm on the attention output
  389. self.post_attention_layernorm = LayerNorm(
  390. config.hidden_size,
  391. eps=config.layernorm_epsilon,
  392. no_persist_layer_norm=config.no_persist_layer_norm,
  393. sequence_parallel=config.sequence_parallel)
  394. # MLP
  395. self.mlp = GPT3ParallelMLP(config, init_method,
  396. output_layer_init_method)
  397. # Set bias+dropout+add fusion grad_enable execution handler.
  398. TORCH_MAJOR = int(torch.__version__.split('.')[0])
  399. TORCH_MINOR = int(torch.__version__.split('.')[1])
  400. use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1
  401. and TORCH_MINOR >= 10)
  402. self.bias_dropout_add_exec_handler = \
  403. nullcontext if use_nvfuser else torch.enable_grad
  404. def forward(self, hidden_states, attention_mask, inference_params=None):
  405. # hidden_states: [s, b, h]
  406. # Layer norm at the beginning of the transformer layer.
  407. layernorm_output = self.input_layernorm(hidden_states)
  408. # Self attention.
  409. attention_output, attention_bias = \
  410. self.self_attention(
  411. layernorm_output,
  412. attention_mask,
  413. inference_params=inference_params)
  414. # Residual connection.
  415. if self.apply_residual_connection_post_layernorm:
  416. residual = layernorm_output
  417. else:
  418. residual = hidden_states
  419. if self.bias_dropout_fusion:
  420. if self.training:
  421. bias_dropout_add_func = bias_dropout_add_fused_train
  422. else:
  423. bias_dropout_add_func = bias_dropout_add_fused_inference
  424. else:
  425. bias_dropout_add_func = get_bias_dropout_add(self.training)
  426. with self.bias_dropout_add_exec_handler():
  427. layernorm_input = bias_dropout_add_func(
  428. attention_output, attention_bias.expand_as(residual), residual,
  429. self.hidden_dropout)
  430. # Layer norm post the self attention.
  431. layernorm_output = self.post_attention_layernorm(layernorm_input)
  432. # MLP.
  433. mlp_output, mlp_bias = self.mlp(layernorm_output)
  434. # Second residual connection.
  435. if self.apply_residual_connection_post_layernorm:
  436. residual = layernorm_output
  437. else:
  438. residual = layernorm_input
  439. with self.bias_dropout_add_exec_handler():
  440. output = bias_dropout_add_func(mlp_output,
  441. mlp_bias.expand_as(residual),
  442. residual, self.hidden_dropout)
  443. # Jit compiled function creates 'view' tensor. This tensor
  444. # potentially gets saved in the MPU checkpoint function context,
  445. # which rejects view tensors. While making a viewless tensor here
  446. # won't result in memory savings (like the data loader, or
  447. # p2p_communication), it serves to document the origin of this
  448. # 'view' tensor.
  449. output = mpu.make_viewless_tensor(
  450. inp=output, requires_grad=output.requires_grad, keep_graph=True)
  451. return output
  452. class GPT3ParallelTransformer(nn.Module):
  453. """Transformer class."""
  454. def __init__(self,
  455. config,
  456. init_method,
  457. output_layer_init_method,
  458. post_layer_norm=True,
  459. pre_process=True,
  460. post_process=True):
  461. super().__init__()
  462. self.bf16 = config.bf16
  463. self.fp32_residual_connection = config.fp32_residual_connection
  464. self.post_layer_norm = post_layer_norm
  465. self.pre_process = pre_process
  466. self.post_process = post_process
  467. self.input_tensor = None
  468. self.sequence_parallel = config.sequence_parallel
  469. # Number of layers.
  470. self.num_layers = config.num_hidden_layers
  471. # Transformer layers.
  472. def build_layer(layer_number):
  473. return GPT3ParallelTransformerLayer(config, init_method,
  474. output_layer_init_method,
  475. layer_number)
  476. if self.num_layers == 0:
  477. self.num_layers = 1
  478. self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
  479. else:
  480. self.layers = torch.nn.ModuleList(
  481. [build_layer(i + 1) for i in range(self.num_layers)])
  482. if self.post_process and self.post_layer_norm:
  483. # Final layer norm before output.
  484. self.final_layernorm = LayerNorm(
  485. config.hidden_size,
  486. eps=config.layernorm_epsilon,
  487. no_persist_layer_norm=config.no_persist_layer_norm,
  488. sequence_parallel=config.sequence_parallel)
  489. def _get_layer(self, layer_number):
  490. return self.layers[layer_number]
  491. def forward(self, hidden_states, attention_mask, inference_params=None):
  492. # hidden_states: [s, b, h]
  493. if not self.pre_process:
  494. # See set_input_tensor()
  495. hidden_states = self.input_tensor
  496. # Viewless tensor.
  497. # - We only need to create a viewless tensor in the case of micro batch
  498. # size (mbs) == 1, since in this case, 'hidden_states.transpose()'
  499. # above creates a view tensor, and '.contiguous()' is a pass-through.
  500. # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
  501. # the need to make it viewless.
  502. #
  503. # However, we don't explicitly check mbs == 1 here because
  504. # make_viewless_tensor() has negligible overhead when its input
  505. # is already viewless.
  506. #
  507. # - For the 'else' case above, calling make_viewless_tensor() here is
  508. # likely redundant, since p2p_communication.py (likely originator)
  509. # already creates viewless tensors. That said, make_viewless_tensor()
  510. # is called here to be future-proof and corner-case-proof.
  511. hidden_states = mpu.make_viewless_tensor(
  512. hidden_states,
  513. requires_grad=True,
  514. keep_graph=True,
  515. )
  516. if self.sequence_parallel:
  517. rng_context = mpu.get_cuda_rng_tracker().fork()
  518. else:
  519. rng_context = nullcontext()
  520. with rng_context:
  521. # Forward pass.
  522. for index in range(self.num_layers):
  523. layer = self._get_layer(index)
  524. hidden_states = layer(
  525. hidden_states,
  526. attention_mask,
  527. inference_params=inference_params)
  528. # Final layer norm.
  529. if self.post_process and self.post_layer_norm:
  530. hidden_states = self.final_layernorm(hidden_states)
  531. return hidden_states
  532. class GPT3TransformerLanguageModel(nn.Module):
  533. """Transformer language model.
  534. Arguments:
  535. transformer_hparams: transformer hyperparameters
  536. vocab_size: vocabulary size
  537. max_sequence_length: maximum size of sequence. This
  538. is used for positional embedding
  539. embedding_dropout_prob: dropout probability for embeddings
  540. num_tokentypes: size of the token-type embeddings. 0 value
  541. will ignore this embedding
  542. """
  543. def __init__(self, config, init_method, output_layer_init_method):
  544. super().__init__()
  545. self.hidden_size = config.hidden_size
  546. self.init_method = init_method
  547. self.encoder_hidden_state = None
  548. # Embeddings.
  549. self.embedding = GPT3Embedding(config, self.init_method)
  550. # Transformer.
  551. self.encoder = GPT3ParallelTransformer(
  552. config,
  553. self.init_method,
  554. output_layer_init_method,
  555. )
  556. def forward(self,
  557. enc_input_ids,
  558. enc_position_ids,
  559. enc_attn_mask,
  560. inference_params=None,
  561. enc_hidden_states=None):
  562. # Encoder embedding.
  563. encoder_input = self.embedding(enc_input_ids, enc_position_ids)
  564. # Run encoder.
  565. if enc_hidden_states is None:
  566. if self.encoder is not None:
  567. encoder_output = self.encoder(
  568. encoder_input,
  569. enc_attn_mask,
  570. inference_params=inference_params)
  571. else:
  572. encoder_output = self.encoder_hidden_state
  573. else:
  574. encoder_output = enc_hidden_states.to(encoder_input.dtype)
  575. return encoder_output
  576. def init_method_normal(sigma):
  577. """Init method based on N(0, sigma)."""
  578. def init_(tensor):
  579. return nn.init.normal_(tensor, mean=0.0, std=sigma)
  580. return init_
  581. def scaled_init_method_normal(sigma, num_layers):
  582. """Init method based on N(0, sigma/sqrt(2*num_layers)."""
  583. std = sigma / math.sqrt(2.0 * num_layers)
  584. def init_(tensor):
  585. return nn.init.normal_(tensor, mean=0.0, std=std)
  586. return init_
  587. class GPT3Model(PreTrainedModel):
  588. config_class = GPT3Config
  589. def __init__(self, config):
  590. super().__init__(config)
  591. self.language_model = GPT3TransformerLanguageModel(
  592. config, init_method_normal(config.init_method_std),
  593. scaled_init_method_normal(config.init_method_std,
  594. config.num_hidden_layers))
  595. def word_embeddings_weight(self):
  596. return self.language_model.embedding.word_embeddings.weight
  597. @staticmethod
  598. def build_attention_mask_and_position_ids(tokens):
  599. seq_length = tokens.size(1)
  600. attention_mask = torch.tril(
  601. torch.ones((1, 1, seq_length, seq_length), device=tokens.device))
  602. attention_mask = (attention_mask < 0.5)
  603. position_ids = torch.arange(
  604. seq_length, dtype=torch.long, device=tokens.device)
  605. position_ids = position_ids.unsqueeze(0).expand_as(tokens)
  606. return attention_mask, position_ids
  607. def forward(self,
  608. input_ids,
  609. attention_mask=None,
  610. position_ids=None,
  611. inference_params=None,
  612. labels=None,
  613. **kwargs):
  614. if attention_mask is None and position_ids is None:
  615. attention_mask, position_ids = \
  616. self.build_attention_mask_and_position_ids(input_ids)
  617. lm_output = self.language_model(
  618. input_ids,
  619. position_ids,
  620. attention_mask,
  621. inference_params=inference_params)
  622. logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
  623. lm_output, self.word_embeddings_weight(), None, False, True,
  624. self.config.sequence_parallel)
  625. losses = None
  626. if labels is not None:
  627. # [b s] => [s b]
  628. labels = labels.transpose(0, 1).contiguous()
  629. losses = mpu.vocab_parallel_cross_entropy(
  630. logits_parallel.clone().float(), labels)
  631. # [s b] => [b s]
  632. losses = losses.transpose(0, 1).contiguous()
  633. # Gather if needed.
  634. logits = mpu.gather_from_tensor_model_parallel_region(logits_parallel)
  635. # [s b h] => [b s h]
  636. logits = logits.transpose(0, 1).contiguous()
  637. return logits, losses
  638. def modify_logits_for_top_k_filtering(logits, top_k):
  639. """Set the logits for none top-k values to -inf."""
  640. filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
  641. logits.masked_fill_(filter_, float('-Inf'))
  642. def modify_logits_for_top_p_filtering(logits, top_p):
  643. """Set the logits for none top-p values to -inf."""
  644. # First sort and calculate cumulative sum of probabilities.
  645. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  646. cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
  647. # Filteration based on the cumulative sum.
  648. filter_ = cumulative_probs > top_p
  649. # This shift by 1 is weird and I cannot justify it. This existed
  650. # in the original implementation:
  651. # https://github.com/ari-holtzman/degen/blob/master/gen.py
  652. # and I guess it is needed so keeping it for now.
  653. filter_[:, 1:] = filter_[:, :-1].clone()
  654. # Make sure we at least have one token to select from.
  655. filter_[..., 0] = 0
  656. # Fill in the filtered part
  657. filter_ = filter_.scatter(1, sorted_indices, filter_)
  658. logits.masked_fill_(filter_, float('-Inf'))
  659. def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
  660. """ Sample and generate a token.
  661. Note: logits has the dimension [b, v] where b is the batch size
  662. and v is the vocabulary size.
  663. If vocab_size is provided, we will make sure the sample that is
  664. generated is in [0, vocab-size). This will avoid out of vocabulary
  665. generations due to padding.
  666. """
  667. # Check logits for consistency.
  668. assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
  669. # Greedy is just simple argmax.
  670. if top_k == 1:
  671. assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
  672. samples = torch.argmax(logits, dim=-1)
  673. # Top-k or top-p sampling.
  674. else:
  675. # Clone so we do not modify the inputs,
  676. logits = logits.clone()
  677. # Apply temperature in place.
  678. if temperature != 1.0:
  679. logits.div_(temperature)
  680. if top_k > 1:
  681. assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
  682. assert top_k <= logits.size(1), 'top-k is larger than logit size.'
  683. if vocab_size:
  684. assert top_k < vocab_size, 'top-k is larger than vocab size.'
  685. modify_logits_for_top_k_filtering(logits, top_k)
  686. elif top_p > 0.0:
  687. assert top_p <= 1.0, 'top-p should be in (0, 1].'
  688. modify_logits_for_top_p_filtering(logits, top_p)
  689. # After filtering, we need to recalculate the distribution.
  690. probs = logits.softmax(dim=-1)
  691. samples = torch.multinomial(probs, num_samples=1).view(-1)
  692. # If vocab size is provided, make sure the samples are in
  693. # in the range [0, vocab-size).
  694. if vocab_size:
  695. samples = torch.clamp(samples, min=0, max=(vocab_size - 1))
  696. return samples
  697. class InferenceParams:
  698. """Inference parameters that are passed to the main model in order
  699. to efficienly calculate and store the context during inference."""
  700. def __init__(self, max_batch_size, max_sequence_len):
  701. """Note that offsets are set to zero and we always set the
  702. flag to allocate memory. After the first call, make sure to
  703. set this flag to False."""
  704. self.max_sequence_len = max_sequence_len
  705. self.max_batch_size = max_batch_size
  706. self.sequence_len_offset = 0
  707. self.batch_size_offset = 0
  708. self.key_value_memory_dict = {}
  709. def swap_key_value_dict(self, batch_idx):
  710. 'swap between batches'
  711. if len(self.key_value_memory_dict) == 0:
  712. raise ValueError('should not swap when dict in empty')
  713. for layer_number in self.key_value_memory_dict.keys():
  714. inference_key_memory, inference_value_memory = self.key_value_memory_dict[
  715. layer_number]
  716. assert len(batch_idx) == inference_key_memory.shape[
  717. 1] # make sure batch size is the same
  718. new_inference_key_memory = inference_key_memory[:, batch_idx]
  719. new_inference_value_memory = inference_value_memory[:, batch_idx]
  720. self.key_value_memory_dict[layer_number] = (
  721. new_inference_key_memory, new_inference_value_memory)
  722. def split_into_partitions(tensor, num_partitions, partition_dim, stride):
  723. per_partition_size = mpu.utils.divide(
  724. tensor.size(partition_dim), num_partitions)
  725. per_partition_per_stride_size = mpu.utils.divide(per_partition_size,
  726. stride)
  727. partitions_list = torch.split(
  728. tensor, per_partition_per_stride_size, dim=partition_dim)
  729. partitions = []
  730. for i in range(num_partitions):
  731. partition = torch.cat(
  732. partitions_list[i::num_partitions], dim=partition_dim)
  733. partitions.append(partition)
  734. return partitions
  735. def split_state_dict(state_dict: Dict[str, torch.Tensor], model: GPT3Model,
  736. partitions: int) -> Dict[str, torch.Tensor]:
  737. if partitions == 1:
  738. return state_dict
  739. rank: int = mpu.get_tensor_model_parallel_rank()
  740. for name, parameters in model.named_parameters():
  741. if parameters.shape == state_dict[name].shape:
  742. continue
  743. dim = max(parameters.partition_dim, 0)
  744. stride = parameters.partition_stride
  745. state_dict[name] = split_into_partitions(state_dict[name], partitions,
  746. dim, stride)[rank]
  747. return state_dict
  748. class DistributedGPT3(TorchModel, StreamingOutputMixin):
  749. def __init__(self,
  750. model_dir,
  751. rank,
  752. path_load_tag='model',
  753. *args,
  754. megatron_cfg=None,
  755. **kwargs):
  756. super().__init__(model_dir, *args, **kwargs)
  757. init_megatron_util(megatron_cfg, model_dir, rank=rank)
  758. self.config = GPT3Config.from_pretrained(model_dir)
  759. # Build model.
  760. model = GPT3Model(self.config)
  761. for param in model.parameters():
  762. mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
  763. # GPU allocation.
  764. model.cuda(torch.cuda.current_device())
  765. # Fp16 conversion.
  766. if self.config.fp16 or self.config.bf16:
  767. model = Float16Module(model, self.config)
  768. self.dist_model = model
  769. tensor_ws = mpu.get_tensor_model_parallel_world_size()
  770. ckpt_ws = get_args().get('checkpoint_tensor_model_parallel_size', None)
  771. ckpt_ws = tensor_ws if ckpt_ws is None else ckpt_ws
  772. ckpt_rank = mpu.get_tensor_model_parallel_rank() * ckpt_ws // tensor_ws
  773. load_model = pre_load(ckpt_rank, model_dir, tag=path_load_tag)
  774. load_model = split_state_dict(load_model, model, tensor_ws // ckpt_ws)
  775. self.dist_model.load_state_dict(
  776. load_model, strict=kwargs.get('strict', True))
  777. self.inference_params = None
  778. def train(self, mode: bool = True):
  779. if mode:
  780. self.inference_params = None
  781. return super().train(mode)
  782. def forward(self,
  783. tokens,
  784. attention_mask=None,
  785. position_ids=None,
  786. labels=None,
  787. prompts_len=None,
  788. inputs_len=None):
  789. logits, losses = self.dist_model(
  790. tokens,
  791. attention_mask,
  792. position_ids,
  793. inference_params=self.inference_params,
  794. labels=labels)
  795. loss = None
  796. if labels is None:
  797. self.inference_params.sequence_len_offset += tokens.size(1)
  798. else:
  799. loss_mask = torch.ones(
  800. labels.size(), dtype=torch.float, device=tokens.device)
  801. if inputs_len is None:
  802. for i, l in enumerate(prompts_len):
  803. loss_mask[i, l:] = 0
  804. else:
  805. for i, l in enumerate(inputs_len):
  806. loss_mask[i, l - 1:] = 0
  807. for i, l in enumerate(prompts_len):
  808. loss_mask[i, :l - 1] = 0
  809. losses = losses.float()
  810. loss_mask = loss_mask.view(-1).float()
  811. mask_sum = loss_mask.sum()
  812. if mask_sum == 0:
  813. loss = torch.sum(losses.view(-1)).zero_()
  814. else:
  815. loss = torch.sum(losses.view(-1) * loss_mask) / mask_sum
  816. return TextGenerationModelOutput(logits=logits, loss=loss)
  817. def sample(self,
  818. tokens,
  819. prompts_len=None,
  820. use_eod_token_for_early_termination=True,
  821. stop_on_double_eol=False,
  822. stop_on_eol=False,
  823. **kwargs):
  824. top_k = kwargs.pop('top_k', self.config.top_k)
  825. top_p = kwargs.pop('top_p', self.config.top_p)
  826. temperature = kwargs.pop('temperature', self.config.temperature)
  827. max_length = kwargs.pop(
  828. 'max_length',
  829. tokens.size(1) + self.config.tokens_to_generate)
  830. batch_size = tokens.size(0)
  831. lengths = prompts_len
  832. if lengths is None:
  833. lengths = torch.tensor([tokens.size(1)], device=tokens.device)
  834. min_prompt_length = lengths.min().item()
  835. max_sequence_length = min(max_length,
  836. self.config.max_position_embeddings)
  837. # If the context is too big, this happens
  838. if min_prompt_length >= max_sequence_length:
  839. raise ValueError('context length + tokens_to_generate too large')
  840. pad_length = max_sequence_length - tokens.size(1)
  841. if pad_length > 0:
  842. pads = torch.zeros(
  843. batch_size, pad_length, device=tokens.device).long()
  844. tokens = torch.cat((tokens, pads), dim=-1)
  845. # Initialize inference parameters.
  846. self.inference_params = InferenceParams(batch_size,
  847. max_sequence_length)
  848. # Added termination_id to support the case that we want to terminate the
  849. # generation once that id is generated.
  850. termination_id = self.config.eod_id
  851. # Whether we have reached a termination id.
  852. is_generation_done = torch.zeros(
  853. batch_size, dtype=torch.uint8, device=torch.cuda.current_device())
  854. # =============
  855. # Run infernece
  856. # =============
  857. attention_mask, position_ids = \
  858. GPT3Model.build_attention_mask_and_position_ids(tokens)
  859. prev_context_length = 0
  860. for context_length in range(min_prompt_length, max_sequence_length):
  861. # Pick the slice that we need to pass through the network.
  862. tokens2use = tokens[:, prev_context_length:context_length]
  863. positions2use = position_ids[:, prev_context_length:context_length]
  864. attention_mask2use = attention_mask[
  865. ..., prev_context_length:context_length, :context_length]
  866. # logits will be meanigful only in the last pipeline stage.
  867. logits = self(tokens2use, attention_mask2use, positions2use).logits
  868. # Sample.
  869. last_token_logits = logits[:, -1, :]
  870. new_sample = sample(
  871. last_token_logits,
  872. top_k=top_k,
  873. top_p=top_p,
  874. temperature=temperature,
  875. vocab_size=self.config.vocab_size)
  876. # If a prompt length is smaller or equal th current context
  877. # length, it means we have started generating tokens
  878. started = lengths <= context_length
  879. # Update the tokens.
  880. tokens[started, context_length] = new_sample[started]
  881. # streaming output
  882. yield TokenGeneratorOutput(sequences=tokens[:, :(context_length
  883. + 1)])
  884. # Update the context length for the next token generation.
  885. prev_context_length = context_length
  886. # instead tokenization should be in the inference loop so stop sequences can be used
  887. if stop_on_double_eol:
  888. hit_double_eol = (new_sample == 628).byte() & started.byte()
  889. hit_two_eols = (new_sample == 198).byte() & (
  890. tokens[:,
  891. context_length - 1] == 198).byte() & started.byte()
  892. done_token = hit_double_eol | hit_two_eols
  893. elif stop_on_eol:
  894. hit_double_eol = (new_sample == 628).byte() & started.byte()
  895. hit_eol = (new_sample == 198).byte() & started.byte()
  896. done_token = hit_double_eol | hit_eol
  897. else:
  898. done_token = (new_sample == termination_id).byte() & \
  899. started.byte()
  900. is_generation_done = is_generation_done | done_token
  901. done = torch.all(is_generation_done)
  902. if use_eod_token_for_early_termination and done:
  903. break
  904. def beam_search(self, tokens, beam_size=5, num_return_gen=1, **kwargs):
  905. batch_size = tokens.size(0)
  906. assert (batch_size == 1)
  907. prompt_length = kwargs.pop(
  908. 'prompt_length',
  909. torch.tensor([tokens.size(1)], device=tokens.device)).item()
  910. stop_token = self.config.eod_id
  911. pads = torch.ones(
  912. 1, self.config.tokens_to_generate,
  913. device=tokens.device).long() * stop_token
  914. tokens = torch.cat((tokens, pads), dim=-1)
  915. final_sequence_length = tokens.size(1)
  916. final_sequence_length = min(final_sequence_length,
  917. self.config.max_position_embeddings)
  918. # If the context is too big, this happens
  919. if prompt_length >= final_sequence_length:
  920. raise ValueError('context length + tokens_to_generate too large')
  921. # Initialize inference parameters.
  922. self.inference_params = InferenceParams(beam_size,
  923. final_sequence_length)
  924. beam_hyp = BeamHypotheses(beam_size)
  925. done = False
  926. scores = torch.zeros(
  927. beam_size, dtype=torch.float32,
  928. device=torch.cuda.current_device()).unsqueeze(1)
  929. # =============
  930. # Run infernece
  931. # =============
  932. tokens = tokens.repeat(beam_size, 1)
  933. attention_mask, position_ids = \
  934. GPT3Model.build_attention_mask_and_position_ids(tokens)
  935. prev_context_length = 0
  936. for context_length in range(prompt_length, final_sequence_length):
  937. # Pick the slice that we need to pass through the network.
  938. tokens2use = tokens[:, prev_context_length:context_length]
  939. positions2use = position_ids[:, prev_context_length:context_length]
  940. attention_mask2use = attention_mask[
  941. ..., prev_context_length:context_length, :context_length]
  942. # logits will be meanigful only in the last pipeline stage.
  943. logits = self(tokens2use, attention_mask2use, positions2use).logits
  944. vocab_size = logits.size(2)
  945. log_probs = F.log_softmax(logits, dim=2)
  946. new_scores = log_probs[:, -1, :] + scores
  947. if context_length == prompt_length: # if this is the first one
  948. sorted_scores, indices = torch.sort(
  949. new_scores[0, :], descending=True)
  950. else:
  951. sorted_scores, indices = torch.sort(
  952. new_scores.view(-1), descending=True)
  953. best_beam_ids = torch.div(indices[:2 * beam_size],
  954. vocab_size).trunc().long()
  955. best_words = indices[:2 * beam_size] % vocab_size
  956. best_scores = sorted_scores[:2 * beam_size]
  957. next_beams = []
  958. for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
  959. zip(best_words, best_scores, best_beam_ids)):
  960. if token_id.item() == stop_token:
  961. # if beam_token does not belong to top num_beams tokens, it should not be added
  962. is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
  963. if is_beam_token_worse_than_top_num_beams:
  964. continue
  965. beam_hyp.add(tokens[beam_id].clone(), beam_score,
  966. context_length + 1 - prompt_length)
  967. else:
  968. # add next predicted token since it is not eos_token
  969. next_beams.append((token_id, beam_score, beam_id))
  970. if len(next_beams) == beam_size:
  971. break
  972. if beam_hyp.is_done(best_scores.max().item(),
  973. context_length + 1 - prompt_length):
  974. done = True
  975. break
  976. best_batches = tokens.new([item[2] for item in next_beams])
  977. tokens = tokens[best_batches, :]
  978. tokens[:, context_length] = tokens.new(
  979. [item[0] for item in next_beams])
  980. scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
  981. # set inference key values to make it consistent with best beam index
  982. self.inference_params.swap_key_value_dict(best_batches)
  983. # Update the context length for the next token generation.
  984. prev_context_length = context_length
  985. # if cannot find stop token, add open beams to hyps
  986. if not done:
  987. for beam_id in range(beam_size):
  988. beam_hyp.add(tokens[beam_id].clone(), scores[beam_id],
  989. context_length + 1 - prompt_length)
  990. # rank based on scores
  991. sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
  992. num_return_gen = min(num_return_gen, len(sorted_hyps))
  993. scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
  994. tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
  995. scores = torch.stack(scores, dim=0)
  996. tokens = torch.stack(tokens, dim=0)
  997. return TokenGeneratorOutput(sequences=tokens, scores=scores)
  998. @torch.no_grad()
  999. def generate(self, tokens, do_sample=True, *args, **kwargs):
  1000. if do_sample:
  1001. last_output = None
  1002. for output in self.sample(tokens, *args, **kwargs):
  1003. last_output = output
  1004. return last_output
  1005. else:
  1006. return self.beam_search(tokens, *args, **kwargs)
  1007. @torch.no_grad()
  1008. def stream_generate(self, tokens, *args, **kwargs):
  1009. return self.sample(tokens, *args, **kwargs)
  1010. def state_dict(self, destination=None, prefix='', keep_vars=False):
  1011. return self.dist_model.state_dict(destination, prefix, keep_vars)
  1012. def load_state_dict(self,
  1013. state_dict: 'OrderedDict[str, torch.Tensor]',
  1014. strict: bool = True):
  1015. return self.dist_model.load_state_dict(state_dict, strict)
  1016. def save_pretrained(self,
  1017. target_folder: Union[str, os.PathLike],
  1018. save_checkpoint_names: Union[str, List[str]] = None,
  1019. save_function: Callable = None,
  1020. config: Optional[dict] = None,
  1021. **kwargs):
  1022. # DistributedPipeline type is different from task name
  1023. config['pipeline']['type'] = 'gpt3-generation'
  1024. config['model'].pop('rank', None)
  1025. config['model'].pop('megatron_cfg', None)
  1026. config['megatron'].pop('rank', None)
  1027. config['megatron'].pop('checkpoint_tensor_model_parallel_size', None)
  1028. tp_size = get_args().tensor_model_parallel_size
  1029. pp_size = get_args().pipeline_model_parallel_size
  1030. config['megatron']['world_size'] = tp_size * pp_size
  1031. return super().save_pretrained(target_folder, save_checkpoint_names,
  1032. save_function, config, **kwargs)
  1033. class BeamHypotheses:
  1034. def __init__(self,
  1035. num_beams: int,
  1036. length_penalty: float = 1.0,
  1037. early_stopping: bool = False):
  1038. """
  1039. Initialize n-best list of hypotheses.
  1040. """
  1041. self.length_penalty = length_penalty
  1042. self.early_stopping = early_stopping
  1043. self.num_beams = num_beams
  1044. self.beams = []
  1045. self.worst_score = 1e9
  1046. def __len__(self):
  1047. """
  1048. Number of hypotheses in the list.
  1049. """
  1050. return len(self.beams)
  1051. def add(self,
  1052. hyp: torch.LongTensor,
  1053. sum_logprobs: float,
  1054. beam_indices: Optional[torch.LongTensor] = None):
  1055. """
  1056. Add a new hypothesis to the list.
  1057. """
  1058. score = sum_logprobs / (hyp.shape[-1]**self.length_penalty)
  1059. if len(self) < self.num_beams or score > self.worst_score:
  1060. self.beams.append((score, hyp, beam_indices))
  1061. if len(self) > self.num_beams:
  1062. sorted_next_scores = sorted([
  1063. (s, idx) for idx, (s, _, _) in enumerate(self.beams)
  1064. ])
  1065. del self.beams[sorted_next_scores[0][1]]
  1066. self.worst_score = sorted_next_scores[1][0]
  1067. else:
  1068. self.worst_score = min(score, self.worst_score)
  1069. def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
  1070. """
  1071. If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
  1072. one in the heap, then we are done with this sentence.
  1073. """
  1074. if len(self) < self.num_beams:
  1075. return False
  1076. elif self.early_stopping:
  1077. return True
  1078. else:
  1079. cur_score = best_sum_logprobs / cur_len**self.length_penalty
  1080. ret = self.worst_score >= cur_score
  1081. return ret