backbone.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. # Copyright 2021-2022 The Alibaba PAI Team Authors.
  2. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import math
  16. import os
  17. from typing import Optional, Union
  18. import addict
  19. import torch
  20. from torch import nn
  21. from torch.nn import functional as F
  22. from transformers.modeling_utils import PreTrainedModel
  23. from modelscope.utils.constant import ModelFile
  24. from .configuration import GPTMoEConfig
  25. class GPTMoESelfAttention(nn.Module):
  26. """Parallel self-attention layer abstract class.
  27. Self-attention layer takes input with size [s, b, h]
  28. and returns output of the same size.
  29. """
  30. def __init__(self, config):
  31. super().__init__()
  32. self.hidden_size = config.hidden_size
  33. self.num_attention_heads = config.num_attention_heads
  34. # Per attention head
  35. self.hidden_size_per_attention_head = \
  36. self.hidden_size // self.num_attention_heads
  37. self.query_key_value = nn.Linear(self.hidden_size,
  38. 3 * self.hidden_size)
  39. self.softmax = nn.Softmax(dim=-1)
  40. self.attention_dropout = nn.Dropout(
  41. config.attention_probs_dropout_prob)
  42. # Output.
  43. self.dense = nn.Linear(self.hidden_size, self.hidden_size)
  44. self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
  45. def _transpose_for_scores(self, tensor):
  46. """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
  47. size [b, np, s, hn].
  48. """
  49. new_tensor_shape = tensor.size()[:-1] + (
  50. self.num_attention_heads, self.hidden_size_per_attention_head)
  51. tensor = tensor.view(*new_tensor_shape)
  52. return tensor.permute(0, 2, 1, 3)
  53. def _split_tensor_along_last_dim(self,
  54. tensor,
  55. num_partitions,
  56. contiguous_split_chunks=False):
  57. # Get the size and dimension.
  58. last_dim = tensor.dim() - 1
  59. last_dim_size = tensor.size()[last_dim] // num_partitions
  60. # Split.
  61. tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
  62. # Note: torch.split does not create contiguous tensors by default.
  63. if contiguous_split_chunks:
  64. return tuple(chunk.contiguous() for chunk in tensor_list)
  65. return tensor_list
  66. def forward(self, hidden_states, ltor_mask, is_infer=False):
  67. # hidden_states: [b, s, h]
  68. # ltor_mask: [1, 1, s, s]
  69. # Attention heads. [b, s, hp]
  70. tgt_len = hidden_states.size(1)
  71. ltor_mask = torch.reshape(ltor_mask, [1, 1, tgt_len, tgt_len])
  72. mixed_x_layer = self.query_key_value(hidden_states)
  73. (mixed_query_layer, mixed_key_layer, mixed_value_layer) = \
  74. self._split_tensor_along_last_dim(mixed_x_layer, 3)
  75. # Reshape and transpose [b, np, s, hn]
  76. query_layer = self._transpose_for_scores(mixed_query_layer)
  77. key_layer = self._transpose_for_scores(mixed_key_layer)
  78. value_layer = self._transpose_for_scores(mixed_value_layer)
  79. previous_type = value_layer.type()
  80. # Raw attention scores. [b, np, s, s]
  81. attention_scores = torch.matmul(query_layer,
  82. key_layer.transpose(-1, -2))
  83. attention_scores = attention_scores / math.sqrt(
  84. self.hidden_size_per_attention_head)
  85. # Apply the left to right attention mask.
  86. if is_infer:
  87. src_len = key_layer.size(2)
  88. ltor_mask = torch.tril(
  89. torch.ones((1, tgt_len, src_len),
  90. device=hidden_states.device)).view(
  91. 1, 1, tgt_len, src_len).type(previous_type)
  92. converted_mask = 10000.0 * (1.0 - ltor_mask)
  93. attention_scores = (torch.mul(attention_scores, ltor_mask)
  94. - converted_mask).type(previous_type)
  95. # Attention probabilities. [b, np, s, s]
  96. attention_probs = self.softmax(attention_scores)
  97. # This is actually dropping out entire tokens to attend to, which might
  98. # seem a bit unusual, but is taken from the original Transformer paper.
  99. attention_probs = self.attention_dropout(attention_probs)
  100. # Context layer.
  101. # [b, np, s, hn]
  102. context_layer = torch.matmul(attention_probs, value_layer)
  103. # [b, s, np, hn]
  104. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  105. new_context_layer_shape = context_layer.size()[:-2] + (
  106. self.hidden_size, )
  107. # [b, s, hp]
  108. context_layer = context_layer.view(*new_context_layer_shape)
  109. # Output. [b, s, h]
  110. output = self.dense(context_layer)
  111. output = self.output_dropout(output)
  112. return output
  113. class GPTMoEMLP(nn.Module):
  114. """MLP.
  115. MLP will take the input with h hidden state, project it to 4*h
  116. hidden dimension, perform nonlinear transformation, and project the
  117. state back into h hidden dimension.
  118. """
  119. def __init__(self, config):
  120. super().__init__()
  121. hidden_size = config.hidden_size
  122. # Project to 4h.
  123. self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
  124. self.activation_func = F.gelu
  125. # Project back to h.
  126. self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
  127. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  128. def forward(self, hidden_states):
  129. # [s, b, 4hp]
  130. intermediate_parallel = self.dense_h_to_4h(hidden_states)
  131. intermediate_parallel = self.activation_func(intermediate_parallel)
  132. # [s, b, h]
  133. output = self.dense_4h_to_h(intermediate_parallel)
  134. output = self.dropout(output)
  135. return output
  136. class GPTMoETransformerLayer(nn.Module):
  137. """A single transformer layer.
  138. Transformer layer takes input with size [s, b, h] and returns an
  139. output of the same size.
  140. """
  141. def __init__(self, config):
  142. super().__init__()
  143. # Layernorm on the input data.
  144. self.input_layernorm = nn.LayerNorm(
  145. config.hidden_size, eps=config.layernorm_epsilon)
  146. # Self attention.
  147. self.attention = GPTMoESelfAttention(config)
  148. # Layernorm on the attention output
  149. self.post_attention_layernorm = nn.LayerNorm(
  150. config.hidden_size, eps=config.layernorm_epsilon)
  151. # MLP
  152. self.mlp = GPTMoEMLP(config)
  153. def forward(self, hidden_states, ltor_mask):
  154. # hidden_states: [b, s, h]
  155. # ltor_mask: [1, 1, s, s]
  156. # Layer norm at the beginning of the transformer layer.
  157. layernorm_output = self.input_layernorm(hidden_states)
  158. # Self attention.
  159. attention_output = self.attention(layernorm_output, ltor_mask)
  160. # Residual connection.
  161. layernorm_input = hidden_states + attention_output
  162. # Layer norm post the self attention.
  163. layernorm_output = self.post_attention_layernorm(layernorm_input)
  164. # MLP.
  165. mlp_output = self.mlp(layernorm_output)
  166. # Second residual connection.
  167. output = layernorm_input + mlp_output
  168. return output
  169. class GPTMoETransformer(nn.Module):
  170. """Transformer class."""
  171. def __init__(self, config):
  172. super().__init__()
  173. self.input_tensor = None
  174. # Number of layers.
  175. self.num_layers = config.num_hidden_layers
  176. self.layers = torch.nn.ModuleList(
  177. [GPTMoETransformerLayer(config) for _ in range(self.num_layers)])
  178. # Final layer norm before output.
  179. self.final_layernorm = nn.LayerNorm(
  180. config.hidden_size, eps=config.layernorm_epsilon)
  181. def _get_layer(self, layer_number):
  182. return self.layers[layer_number]
  183. def forward(self, hidden_states, attention_mask):
  184. # hidden_states: [s, b, h]
  185. for index in range(self.num_layers):
  186. layer = self._get_layer(index)
  187. hidden_states = layer(hidden_states, attention_mask)
  188. # Final layer norm.
  189. hidden_states = self.final_layernorm(hidden_states)
  190. return hidden_states
  191. class GPTMoETransformerLanguageModel(nn.Module):
  192. """Transformer language model.
  193. Arguments:
  194. transformer_hparams: transformer hyperparameters
  195. vocab_size: vocabulary size
  196. max_sequence_length: maximum size of sequence. This
  197. is used for positional embedding
  198. embedding_dropout_prob: dropout probability for embeddings
  199. num_tokentypes: size of the token-type embeddings. 0 value
  200. will ignore this embedding
  201. """
  202. def __init__(self, config):
  203. super().__init__()
  204. # Embeddings.
  205. self.word_embeddings = nn.Embedding(config.vocab_size,
  206. config.hidden_size)
  207. self.position_embeddings = nn.Embedding(config.max_position_embeddings,
  208. config.hidden_size)
  209. self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
  210. # Transformer.
  211. self.transformer = GPTMoETransformer(config)
  212. def forward(self, input_ids, attention_mask, position_ids):
  213. words_embeddings = self.word_embeddings(input_ids)
  214. position_embeddings = self.position_embeddings(position_ids)
  215. embeddings = words_embeddings + position_embeddings
  216. transformer_input = self.embedding_dropout(embeddings)
  217. transformer_output = self.transformer(transformer_input,
  218. attention_mask)
  219. logits = F.linear(transformer_output, self.word_embeddings.weight)
  220. return logits
  221. class GPTMoEModel(PreTrainedModel):
  222. config_class = GPTMoEConfig
  223. def _init_weights(self, module):
  224. """Initialize the weights"""
  225. if isinstance(module, nn.Linear):
  226. # Slightly different from the TF version which uses truncated_normal for initialization
  227. # cf https://github.com/pytorch/pytorch/pull/5617
  228. module.weight.data.normal_(
  229. mean=0.0, std=self.config.initializer_range)
  230. if module.bias is not None:
  231. module.bias.data.zero_()
  232. elif isinstance(module, nn.Embedding):
  233. module.weight.data.normal_(
  234. mean=0.0, std=self.config.initializer_range)
  235. if module.padding_idx is not None:
  236. module.weight.data[module.padding_idx].zero_()
  237. elif isinstance(module, nn.LayerNorm):
  238. module.bias.data.zero_()
  239. module.weight.data.fill_(1.0)
  240. def __init__(self, config):
  241. super().__init__(config)
  242. self.language_model = GPTMoETransformerLanguageModel(config)
  243. def forward(self,
  244. input_ids,
  245. attention_mask=None,
  246. position_ids=None,
  247. labels=None,
  248. **kwargs):
  249. seq_length = input_ids.size(1)
  250. attention_mask = torch.tril(
  251. torch.ones((1, 1, seq_length, seq_length),
  252. dtype=torch.long,
  253. device=input_ids.device))
  254. if position_ids is None:
  255. position_ids = torch.arange(
  256. seq_length, dtype=torch.long, device=input_ids.device)
  257. position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
  258. logits = self.language_model(input_ids, attention_mask, position_ids)
  259. loss = None
  260. if labels is not None:
  261. loss_fct = nn.CrossEntropyLoss()
  262. loss = loss_fct(
  263. logits.view(-1, self.config.vocab_size), labels.view(-1))
  264. return addict.Dict(loss=loss, logits=logits)
  265. @classmethod
  266. def from_pretrained(
  267. cls, pretrained_model_name_or_path: Optional[Union[str,
  268. os.PathLike]]):
  269. config = cls.config_class.from_pretrained(
  270. pretrained_model_name_or_path)
  271. model = cls(config)
  272. state_dict_file = os.path.join(pretrained_model_name_or_path,
  273. ModelFile.TORCH_MODEL_BIN_FILE)
  274. state_dict = torch.load(state_dict_file)
  275. if 'state_dict' in state_dict:
  276. state_dict = state_dict['state_dict']
  277. state_dict = {
  278. k.replace('model.language_model', 'language_model'): v
  279. for k, v in state_dict.items()
  280. }
  281. model.load_state_dict(state_dict)
  282. return model
  283. def prepare_inputs_for_generation(self, input_ids, *args, **kwargs):
  284. return {'input_ids': input_ids}