backbone.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
  2. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import math
  16. import os
  17. from typing import Optional, Union
  18. import addict
  19. import torch
  20. from torch import nn
  21. from torch.nn import functional as F
  22. from transformers.modeling_utils import PreTrainedModel
  23. from modelscope.outputs import TokenGeneratorOutput
  24. from modelscope.utils.constant import ModelFile
  25. from .configuration import GPT3Config
  26. from .distributed_gpt3 import sample
  27. class GPT3SelfAttention(nn.Module):
  28. """Parallel self-attention layer abstract class.
  29. Self-attention layer takes input with size [s, b, h]
  30. and returns output of the same size.
  31. """
  32. def __init__(self, config):
  33. super().__init__()
  34. self.hidden_size = config.hidden_size
  35. self.num_attention_heads = config.num_attention_heads
  36. # Per attention head
  37. self.hidden_size_per_attention_head = \
  38. self.hidden_size // self.num_attention_heads
  39. self.query_key_value = nn.Linear(self.hidden_size,
  40. 3 * self.hidden_size)
  41. self.softmax = nn.Softmax(dim=-1)
  42. self.attention_dropout = nn.Dropout(
  43. config.attention_probs_dropout_prob)
  44. # Output.
  45. self.dense = nn.Linear(self.hidden_size, self.hidden_size)
  46. self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
  47. def _transpose_for_scores(self, tensor):
  48. """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
  49. size [b, np, s, hn].
  50. """
  51. new_tensor_shape = tensor.size()[:-1] + (
  52. self.num_attention_heads, self.hidden_size_per_attention_head)
  53. tensor = tensor.view(*new_tensor_shape)
  54. return tensor.permute(0, 2, 1, 3)
  55. def _split_tensor_along_last_dim(self,
  56. tensor,
  57. num_partitions,
  58. contiguous_split_chunks=False):
  59. # Get the size and dimension.
  60. last_dim = tensor.dim() - 1
  61. last_dim_size = tensor.size()[last_dim] // num_partitions
  62. # Split.
  63. tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
  64. # Note: torch.split does not create contiguous tensors by default.
  65. if contiguous_split_chunks:
  66. return tuple(chunk.contiguous() for chunk in tensor_list)
  67. return tensor_list
  68. def forward(self, hidden_states, ltor_mask, is_infer=False):
  69. # hidden_states: [b, s, h]
  70. # ltor_mask: [1, 1, s, s]
  71. # Attention heads. [b, s, hp]
  72. tgt_len = hidden_states.size(1)
  73. ltor_mask = torch.reshape(ltor_mask, [1, 1, tgt_len, tgt_len])
  74. mixed_x_layer = self.query_key_value(hidden_states)
  75. (mixed_query_layer, mixed_key_layer, mixed_value_layer) = \
  76. self._split_tensor_along_last_dim(mixed_x_layer, 3)
  77. # Reshape and transpose [b, np, s, hn]
  78. query_layer = self._transpose_for_scores(mixed_query_layer)
  79. key_layer = self._transpose_for_scores(mixed_key_layer)
  80. value_layer = self._transpose_for_scores(mixed_value_layer)
  81. previous_type = value_layer.type()
  82. # Raw attention scores. [b, np, s, s]
  83. attention_scores = torch.matmul(query_layer,
  84. key_layer.transpose(-1, -2))
  85. attention_scores = attention_scores / math.sqrt(
  86. self.hidden_size_per_attention_head)
  87. # Apply the left to right attention mask.
  88. if is_infer:
  89. src_len = key_layer.size(2)
  90. ltor_mask = torch.tril(
  91. torch.ones((1, tgt_len, src_len),
  92. device=hidden_states.device)).view(
  93. 1, 1, tgt_len, src_len).type(previous_type)
  94. converted_mask = 10000.0 * (1.0 - ltor_mask)
  95. attention_scores = (torch.mul(attention_scores, ltor_mask)
  96. - converted_mask).type(previous_type)
  97. # Attention probabilities. [b, np, s, s]
  98. attention_probs = self.softmax(attention_scores)
  99. # This is actually dropping out entire tokens to attend to, which might
  100. # seem a bit unusual, but is taken from the original Transformer paper.
  101. attention_probs = self.attention_dropout(attention_probs)
  102. # Context layer.
  103. # [b, np, s, hn]
  104. context_layer = torch.matmul(attention_probs, value_layer)
  105. # [b, s, np, hn]
  106. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  107. new_context_layer_shape = context_layer.size()[:-2] + (
  108. self.hidden_size, )
  109. # [b, s, hp]
  110. context_layer = context_layer.view(*new_context_layer_shape)
  111. # Output. [b, s, h]
  112. output = self.dense(context_layer)
  113. output = self.output_dropout(output)
  114. return output
  115. class GPT3MLP(nn.Module):
  116. """MLP.
  117. MLP will take the input with h hidden state, project it to 4*h
  118. hidden dimension, perform nonlinear transformation, and project the
  119. state back into h hidden dimension.
  120. """
  121. def __init__(self, config):
  122. super().__init__()
  123. hidden_size = config.hidden_size
  124. # Project to 4h.
  125. self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
  126. self.activation_func = F.gelu
  127. # Project back to h.
  128. self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
  129. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  130. def forward(self, hidden_states):
  131. # [s, b, 4hp]
  132. intermediate_parallel = self.dense_h_to_4h(hidden_states)
  133. intermediate_parallel = self.activation_func(intermediate_parallel)
  134. # [s, b, h]
  135. output = self.dense_4h_to_h(intermediate_parallel)
  136. output = self.dropout(output)
  137. return output
  138. class GPT3TransformerLayer(nn.Module):
  139. """A single transformer layer.
  140. Transformer layer takes input with size [s, b, h] and returns an
  141. output of the same size.
  142. """
  143. def __init__(self, config):
  144. super().__init__()
  145. # Layernorm on the input data.
  146. self.input_layernorm = nn.LayerNorm(
  147. config.hidden_size, eps=config.layernorm_epsilon)
  148. # Self attention.
  149. self.attention = GPT3SelfAttention(config)
  150. # Layernorm on the attention output
  151. self.post_attention_layernorm = nn.LayerNorm(
  152. config.hidden_size, eps=config.layernorm_epsilon)
  153. # MLP
  154. self.mlp = GPT3MLP(config)
  155. def forward(self, hidden_states, ltor_mask):
  156. # hidden_states: [b, s, h]
  157. # ltor_mask: [1, 1, s, s]
  158. # Layer norm at the beginning of the transformer layer.
  159. layernorm_output = self.input_layernorm(hidden_states)
  160. # Self attention.
  161. attention_output = self.attention(layernorm_output, ltor_mask)
  162. # Residual connection.
  163. layernorm_input = hidden_states + attention_output
  164. # Layer norm post the self attention.
  165. layernorm_output = self.post_attention_layernorm(layernorm_input)
  166. # MLP.
  167. mlp_output = self.mlp(layernorm_output)
  168. # Second residual connection.
  169. output = layernorm_input + mlp_output
  170. return output
  171. class GPT3Transformer(nn.Module):
  172. """Transformer class."""
  173. def __init__(self, config):
  174. super().__init__()
  175. self.input_tensor = None
  176. # Number of layers.
  177. self.num_layers = config.num_hidden_layers
  178. self.layers = torch.nn.ModuleList(
  179. [GPT3TransformerLayer(config) for _ in range(self.num_layers)])
  180. # Final layer norm before output.
  181. self.final_layernorm = nn.LayerNorm(
  182. config.hidden_size, eps=config.layernorm_epsilon)
  183. def _get_layer(self, layer_number):
  184. return self.layers[layer_number]
  185. def forward(self, hidden_states, attention_mask):
  186. # hidden_states: [s, b, h]
  187. for index in range(self.num_layers):
  188. layer = self._get_layer(index)
  189. hidden_states = layer(hidden_states, attention_mask)
  190. # Final layer norm.
  191. hidden_states = self.final_layernorm(hidden_states)
  192. return hidden_states
  193. class GPT3TransformerLanguageModel(nn.Module):
  194. """Transformer language model.
  195. Arguments:
  196. transformer_hparams: transformer hyperparameters
  197. vocab_size: vocabulary size
  198. max_sequence_length: maximum size of sequence. This
  199. is used for positional embedding
  200. embedding_dropout_prob: dropout probability for embeddings
  201. num_tokentypes: size of the token-type embeddings. 0 value
  202. will ignore this embedding
  203. """
  204. def __init__(self, config):
  205. super().__init__()
  206. # Embeddings.
  207. self.word_embeddings = nn.Embedding(config.vocab_size,
  208. config.hidden_size)
  209. self.position_embeddings = nn.Embedding(config.max_position_embeddings,
  210. config.hidden_size)
  211. self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
  212. # Transformer.
  213. self.transformer = GPT3Transformer(config)
  214. def forward(self, input_ids, attention_mask, position_ids):
  215. words_embeddings = self.word_embeddings(input_ids)
  216. position_embeddings = self.position_embeddings(position_ids)
  217. embeddings = words_embeddings + position_embeddings
  218. transformer_input = self.embedding_dropout(embeddings)
  219. transformer_output = self.transformer(transformer_input,
  220. attention_mask)
  221. logits = F.linear(transformer_output, self.word_embeddings.weight)
  222. return logits
  223. class GPT3Model(PreTrainedModel):
  224. config_class = GPT3Config
  225. def _init_weights(self, module):
  226. """Initialize the weights"""
  227. if isinstance(module, nn.Linear):
  228. # Slightly different from the TF version which uses truncated_normal for initialization
  229. # cf https://github.com/pytorch/pytorch/pull/5617
  230. module.weight.data.normal_(
  231. mean=0.0, std=self.config.initializer_range)
  232. if module.bias is not None:
  233. module.bias.data.zero_()
  234. elif isinstance(module, nn.Embedding):
  235. module.weight.data.normal_(
  236. mean=0.0, std=self.config.initializer_range)
  237. if module.padding_idx is not None:
  238. module.weight.data[module.padding_idx].zero_()
  239. elif isinstance(module, nn.LayerNorm):
  240. module.bias.data.zero_()
  241. module.weight.data.fill_(1.0)
  242. def __init__(self, config):
  243. super().__init__(config)
  244. self.language_model = GPT3TransformerLanguageModel(config)
  245. def forward(self,
  246. input_ids,
  247. attention_mask=None,
  248. position_ids=None,
  249. labels=None,
  250. **kwargs):
  251. seq_length = input_ids.size(1)
  252. attention_mask = torch.tril(
  253. torch.ones((1, 1, seq_length, seq_length),
  254. dtype=torch.long,
  255. device=input_ids.device))
  256. if position_ids is None:
  257. position_ids = torch.arange(
  258. seq_length, dtype=torch.long, device=input_ids.device)
  259. position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
  260. logits = self.language_model(input_ids, attention_mask, position_ids)
  261. loss = None
  262. if labels is not None:
  263. loss_fct = nn.CrossEntropyLoss()
  264. loss = loss_fct(
  265. logits.view(-1, self.config.vocab_size), labels.view(-1))
  266. return addict.Dict(loss=loss, logits=logits)
  267. @classmethod
  268. def from_pretrained(
  269. cls, pretrained_model_name_or_path: Optional[Union[str,
  270. os.PathLike]]):
  271. config = cls.config_class.from_pretrained(
  272. pretrained_model_name_or_path)
  273. model = cls(config)
  274. state_dict_file = os.path.join(pretrained_model_name_or_path,
  275. ModelFile.TORCH_MODEL_BIN_FILE)
  276. state_dict = torch.load(state_dict_file)
  277. if 'state_dict' in state_dict:
  278. state_dict = state_dict['state_dict']
  279. state_dict = {
  280. k.replace('model.language_model', 'language_model'): v
  281. for k, v in state_dict.items()
  282. }
  283. model.load_state_dict(state_dict)
  284. return model
  285. def streaming_generate(self, tokens, temperature=1.0, **kwargs):
  286. top_k = kwargs.pop('top_k', self.config.top_k)
  287. top_p = kwargs.pop('top_p', self.config.top_p)
  288. max_length = kwargs.pop('max_length', tokens.size(1) + 100)
  289. batch_size = tokens.size(0)
  290. lengths = kwargs.pop(
  291. 'prompt_length',
  292. torch.tensor([tokens.size(1)], device=tokens.device))
  293. min_prompt_length = lengths.min().item()
  294. max_sequence_length = min(max_length,
  295. self.config.max_position_embeddings)
  296. # If the context is too big, this happens
  297. if min_prompt_length >= max_sequence_length:
  298. raise ValueError('context length too large')
  299. pad_length = max_sequence_length - tokens.size(1)
  300. if pad_length > 0:
  301. pads = torch.zeros(
  302. batch_size, pad_length, device=tokens.device).long()
  303. tokens = torch.cat((tokens, pads), dim=-1)
  304. # Added termination_id to support the case that we want to terminate the
  305. # generation once that id is generated.
  306. termination_id = self.config.eod_id
  307. # Whether we have reached a termination id.
  308. is_generation_done = torch.zeros(
  309. batch_size, dtype=torch.uint8, device=tokens.device)
  310. with torch.no_grad():
  311. for context_length in range(min_prompt_length,
  312. max_sequence_length):
  313. # Pick the slice that we need to pass through the network.
  314. tokens2use = tokens[:, :context_length]
  315. # logits will be meanigful only in the last pipeline stage.
  316. logits = self(tokens2use).logits
  317. # Sample.
  318. last_token_logits = logits[:, -1, :]
  319. new_sample = sample(
  320. last_token_logits,
  321. top_k=top_k,
  322. top_p=top_p,
  323. temperature=temperature,
  324. vocab_size=self.config.vocab_size)
  325. # If a prompt length is smaller or equal th current context
  326. # length, it means we have started generating tokens
  327. started = lengths <= context_length
  328. # Update the tokens.
  329. tokens[started, context_length] = new_sample[started]
  330. yield TokenGeneratorOutput(sequences=tokens[:, :(context_length
  331. + 1)])
  332. done_token = (new_sample == termination_id).byte() & \
  333. started.byte()
  334. is_generation_done = is_generation_done | done_token
  335. done = torch.all(is_generation_done)
  336. if done:
  337. break
  338. def generate(self, tokens, temperature=1.0, **kwargs):
  339. last_output = None
  340. for output in self.streaming_generate(tokens, temperature, **kwargs):
  341. last_output = output
  342. return last_output