backbone.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009
  1. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from __future__ import (absolute_import, division, print_function,
  17. unicode_literals)
  18. import logging
  19. import math
  20. import torch
  21. import torch.nn.functional as F
  22. from megatron_util import mpu
  23. from torch import nn
  24. from modelscope.utils.nlp.distributed import (normal_init_method,
  25. scaled_init_method)
  26. from .configuration import PlugNLGConfig, PlugNLUConfig
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  28. def gelu(x):
  29. """Implementation of the gelu activation function.
  30. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
  31. 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
  32. """
  33. return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
  34. def swish(x):
  35. return x * torch.sigmoid(x)
  36. ACT2FN = {'gelu': gelu, 'relu': torch.nn.functional.relu, 'swish': swish}
  37. class BertLayerNorm(nn.Module):
  38. def __init__(self, hidden_size, eps=1e-12):
  39. """Construct a layernorm module in the TF style (epsilon inside the square root).
  40. """
  41. super(BertLayerNorm, self).__init__()
  42. self.weight = nn.Parameter(torch.ones(hidden_size))
  43. self.bias = nn.Parameter(torch.zeros(hidden_size))
  44. self.variance_epsilon = eps
  45. def forward(self, x):
  46. u = x.mean(-1, keepdim=True)
  47. s = (x - u).pow(2).mean(-1, keepdim=True)
  48. x = (x - u) / torch.sqrt(s + self.variance_epsilon)
  49. return self.weight * x + self.bias
  50. class BertEmbeddings(nn.Module):
  51. """Construct the embeddings from word, position and token_type embeddings.
  52. """
  53. def __init__(self, config):
  54. super(BertEmbeddings, self).__init__()
  55. self.word_embeddings = mpu.VocabParallelEmbedding(
  56. config.vocab_size,
  57. config.hidden_size,
  58. init_method=normal_init_method(
  59. mean=0.0, std=config.initializer_range))
  60. self.position_embeddings = nn.Embedding(config.max_position_embeddings,
  61. config.hidden_size)
  62. self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
  63. config.hidden_size)
  64. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  65. # any TensorFlow checkpoint file
  66. self.fp32_layernorm = config.fp32_layernorm
  67. self.fp32_embedding = config.fp32_embedding
  68. self.fp32_tokentypes = config.fp32_tokentypes
  69. self.LayerNorm = BertLayerNorm(
  70. config.hidden_size, eps=config.layernorm_epsilon)
  71. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  72. def forward(self, input_ids, token_type_ids=None, position_ids=None):
  73. seq_length = input_ids.size(1)
  74. if position_ids is None:
  75. position_ids = torch.arange(
  76. seq_length, dtype=torch.long, device=input_ids.device)
  77. position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
  78. if token_type_ids is None:
  79. token_type_ids = torch.zeros_like(input_ids)
  80. words_embeddings = self.word_embeddings(input_ids)
  81. position_embeddings = self.position_embeddings(position_ids)
  82. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  83. if not self.fp32_tokentypes:
  84. embeddings = words_embeddings + position_embeddings + token_type_embeddings
  85. if self.fp32_embedding and not self.fp32_layernorm:
  86. embeddings = embeddings.half()
  87. previous_type = embeddings.type()
  88. if self.fp32_layernorm:
  89. embeddings = embeddings.float()
  90. embeddings = self.LayerNorm(embeddings)
  91. if self.fp32_layernorm:
  92. if self.fp32_embedding:
  93. embeddings = embeddings.half()
  94. else:
  95. embeddings = embeddings.type(previous_type)
  96. else:
  97. embeddings = words_embeddings.float() + position_embeddings.float(
  98. ) + token_type_embeddings.float()
  99. if self.fp32_tokentypes and not self.fp32_layernorm:
  100. embeddings = embeddings.half()
  101. previous_type = embeddings.type()
  102. if self.fp32_layernorm:
  103. embeddings = embeddings.float()
  104. embeddings = self.LayerNorm(embeddings)
  105. if self.fp32_layernorm:
  106. if self.fp32_tokentypes:
  107. embeddings = embeddings.half()
  108. else:
  109. embeddings = embeddings.type(previous_type)
  110. embeddings = self.dropout(embeddings)
  111. return embeddings
  112. class BertSelfOutput(nn.Module):
  113. def __init__(self, config):
  114. super(BertSelfOutput, self).__init__()
  115. if hasattr(config, 'deep_init') and config.deep_init:
  116. init_method = scaled_init_method(
  117. mean=0.0,
  118. std=config.initializer_range,
  119. num_layers=config.num_hidden_layers)
  120. else:
  121. init_method = normal_init_method(
  122. mean=0.0, std=config.initializer_range)
  123. self.dense = mpu.RowParallelLinear(
  124. input_size=config.hidden_size,
  125. output_size=config.hidden_size,
  126. bias=True,
  127. input_is_parallel=True,
  128. stride=1,
  129. init_method=init_method)
  130. self.fp32_layernorm = config.fp32_layernorm
  131. if not config.pre_ln:
  132. self.LayerNorm = BertLayerNorm(
  133. config.hidden_size, eps=config.layernorm_epsilon)
  134. else:
  135. self.LayerNorm = None
  136. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  137. def forward(
  138. self,
  139. hidden_states,
  140. input_tensor,
  141. ):
  142. hidden_states = self.dense(hidden_states)
  143. hidden_states = self.dropout(hidden_states)
  144. ln_input = hidden_states + input_tensor
  145. if self.LayerNorm is not None:
  146. previous_type = ln_input.type()
  147. if self.fp32_layernorm:
  148. ln_input = ln_input.float()
  149. hidden_states = self.LayerNorm(ln_input)
  150. if self.fp32_layernorm:
  151. hidden_states = hidden_states.type(previous_type)
  152. else:
  153. hidden_states = ln_input
  154. return hidden_states
  155. class BertAttention(nn.Module):
  156. def __init__(self, config):
  157. super(BertAttention, self).__init__()
  158. self.fp32_layernorm = config.fp32_layernorm
  159. if config.pre_ln:
  160. self.LayerNorm = BertLayerNorm(
  161. config.hidden_size, eps=config.layernorm_epsilon)
  162. else:
  163. self.LayerNorm = None
  164. self.self = mpu.BertParallelSelfAttention(
  165. hidden_size=config.hidden_size,
  166. num_attention_heads=config.num_attention_heads,
  167. dropout_prob=config.attention_probs_dropout_prob,
  168. output_parallel=True,
  169. init_method=normal_init_method(
  170. mean=0.0, std=config.initializer_range),
  171. separate=config.attn_separate)
  172. self.output = BertSelfOutput(config)
  173. def forward(
  174. self,
  175. input_tensor,
  176. attention_mask,
  177. ):
  178. if self.LayerNorm is not None:
  179. ln_input = input_tensor
  180. previous_type = input_tensor.type()
  181. if self.fp32_layernorm:
  182. ln_input = input_tensor.float()
  183. ln_output = self.LayerNorm(ln_input)
  184. if self.fp32_layernorm:
  185. ln_output = ln_output.type(previous_type)
  186. self_output = self.self(
  187. ln_output,
  188. attention_mask,
  189. )
  190. else:
  191. self_output = self.self(
  192. input_tensor,
  193. attention_mask,
  194. )
  195. attention_output = self.output(
  196. self_output,
  197. input_tensor,
  198. )
  199. return attention_output
  200. class BertIntermediate(nn.Module):
  201. def __init__(self, config):
  202. super(BertIntermediate, self).__init__()
  203. self.dense = mpu.ColumnParallelLinear(
  204. input_size=config.hidden_size,
  205. output_size=config.intermediate_size,
  206. bias=True,
  207. gather_output=False,
  208. stride=1,
  209. init_method=normal_init_method(
  210. mean=0.0, std=config.initializer_range))
  211. self.intermediate_act_fn = ACT2FN[config.hidden_act] \
  212. if isinstance(config.hidden_act, str) else config.hidden_act
  213. def forward(
  214. self,
  215. hidden_states,
  216. ):
  217. hidden_states = self.dense(hidden_states)
  218. hidden_states = self.intermediate_act_fn(hidden_states)
  219. return hidden_states
  220. class BertOutput(nn.Module):
  221. def __init__(self, config):
  222. super(BertOutput, self).__init__()
  223. if hasattr(config, 'deep_init') and config.deep_init:
  224. init_method = scaled_init_method(
  225. mean=0.0,
  226. std=config.initializer_range,
  227. num_layers=config.num_hidden_layers)
  228. else:
  229. init_method = normal_init_method(
  230. mean=0.0, std=config.initializer_range)
  231. self.dense = mpu.RowParallelLinear(
  232. input_size=config.intermediate_size,
  233. output_size=config.hidden_size,
  234. bias=True,
  235. input_is_parallel=True,
  236. stride=1,
  237. init_method=init_method)
  238. self.fp32_layernorm = config.fp32_layernorm
  239. if not config.pre_ln:
  240. self.LayerNorm = BertLayerNorm(
  241. config.hidden_size, eps=config.layernorm_epsilon)
  242. else:
  243. self.LayerNorm = None
  244. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  245. def forward(
  246. self,
  247. hidden_states,
  248. input_tensor,
  249. ):
  250. hidden_states = self.dense(hidden_states)
  251. hidden_states = self.dropout(hidden_states)
  252. ln_input = hidden_states + input_tensor
  253. if self.LayerNorm is not None:
  254. previous_type = ln_input.type()
  255. if self.fp32_layernorm:
  256. ln_input = ln_input.float()
  257. hidden_states = self.LayerNorm(ln_input)
  258. if self.fp32_layernorm:
  259. hidden_states = hidden_states.type(previous_type)
  260. else:
  261. hidden_states = ln_input
  262. return hidden_states
  263. class BertLayer(nn.Module):
  264. def __init__(self, config):
  265. super(BertLayer, self).__init__()
  266. self.attention = BertAttention(config)
  267. self.intermediate = BertIntermediate(config)
  268. self.output = BertOutput(config)
  269. self.fp32_layernorm = config.fp32_layernorm
  270. if config.pre_ln:
  271. self.LayerNorm = BertLayerNorm(
  272. config.hidden_size, eps=config.layernorm_epsilon)
  273. else:
  274. self.LayerNorm = None
  275. def forward(self, hidden_states, attention_mask):
  276. attention_output = self.attention(hidden_states, attention_mask)
  277. if self.LayerNorm is not None:
  278. ln_input = attention_output
  279. previous_type = attention_output.type()
  280. if self.fp32_layernorm:
  281. ln_input = attention_output.float()
  282. ln_output = self.LayerNorm(ln_input)
  283. if self.fp32_layernorm:
  284. ln_output = ln_output.type(previous_type)
  285. intermediate_output = self.intermediate(ln_output)
  286. else:
  287. intermediate_output = self.intermediate(attention_output)
  288. layer_output = self.output(intermediate_output, attention_output)
  289. return layer_output
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` BertLayer modules.

    When ``config.pre_ln`` is set, an extra final LayerNorm is created and
    applied to the output of the stack.
    """

    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.layer = nn.ModuleList(
            [BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.fp32_layernorm = config.fp32_layernorm
        if config.pre_ln:
            self.LayerNorm = BertLayerNorm(
                config.hidden_size, eps=config.layernorm_epsilon)
        else:
            self.LayerNorm = None

    def forward(
        self,
        hidden_states,
        attention_mask,
        output_all_encoded_layers=True,
        checkpoint_activations=False,
        detach_index=-1,
    ):
        """Run the layer stack and return a list of hidden states.

        Args:
            hidden_states: input to the first layer.
            attention_mask: mask forwarded unchanged to every layer.
            output_all_encoded_layers: when true (and not checkpointing), the
                output of every layer is collected; otherwise only the final
                hidden state is returned (as a one-element list).
            checkpoint_activations: run each layer through ``mpu.checkpoint``.
                In this mode only the final hidden state is returned.
            detach_index: index of the layer whose output is detached in
                place; the default ``-1`` matches no layer.

        Returns:
            A list of hidden-state tensors.
        """
        all_encoder_layers = []

        def custom(start, end):
            # Build a callable that runs layers [start, end) so that it can be
            # handed to mpu.checkpoint.

            def custom_forward(*inputs):
                layers = self.layer[start:end]
                x_ = inputs[0]
                for layer in layers:
                    x_ = layer(x_, inputs[1])
                return x_

            return custom_forward

        if checkpoint_activations:
            layer_idx = 0
            num_layers = len(self.layer)
            # One layer per checkpointed chunk.
            chunk_length = 1
            while layer_idx < num_layers:
                hidden_states = mpu.checkpoint(
                    custom(layer_idx, layer_idx + chunk_length), hidden_states,
                    attention_mask * 1)
                if detach_index == layer_idx:
                    hidden_states.detach_()
                layer_idx += chunk_length
        else:
            for i, layer_module in enumerate(self.layer):
                hidden_states = layer_module(hidden_states, attention_mask)
                if detach_index == i:
                    hidden_states.detach_()
                # Apply the final LayerNorm to the last layer's output.
                if i == len(self.layer) - 1 and self.LayerNorm is not None:
                    previous_type = hidden_states.type()
                    if self.fp32_layernorm:
                        hidden_states = hidden_states.float()
                    hidden_states = self.LayerNorm(hidden_states)
                    if self.fp32_layernorm:
                        hidden_states = hidden_states.type(previous_type)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)

        if not output_all_encoded_layers or checkpoint_activations:
            # NOTE(review): with pre_ln set, checkpoint_activations False and
            # output_all_encoded_layers False, the final LayerNorm appears to
            # run both inside the loop above and again here - confirm that
            # this is intended before changing either site.
            if self.LayerNorm is not None:
                previous_type = hidden_states.type()
                if self.fp32_layernorm:
                    hidden_states = hidden_states.float()
                hidden_states = self.LayerNorm(hidden_states)
                if self.fp32_layernorm:
                    hidden_states = hidden_states.type(previous_type)
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers
  354. class BertPooler(nn.Module):
  355. def __init__(self, config):
  356. super(BertPooler, self).__init__()
  357. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  358. self.activation = nn.Tanh()
  359. def forward(self, hidden_states):
  360. # We "pool" the model by simply taking the hidden state corresponding
  361. # to the first token.
  362. first_token_tensor = hidden_states[:, 0]
  363. pooled_output = self.dense(first_token_tensor)
  364. pooled_output = self.activation(pooled_output)
  365. return pooled_output
  366. class BertPredictionHeadTransform(nn.Module):
  367. def __init__(self, config):
  368. super(BertPredictionHeadTransform, self).__init__()
  369. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  370. self.transform_act_fn = ACT2FN[config.hidden_act] \
  371. if isinstance(config.hidden_act, str) else config.hidden_act
  372. self.LayerNorm = BertLayerNorm(
  373. config.hidden_size, eps=config.layernorm_epsilon)
  374. self.fp32_layernorm = config.fp32_layernorm
  375. def forward(self, hidden_states):
  376. hidden_states = self.dense(hidden_states)
  377. hidden_states = self.transform_act_fn(hidden_states)
  378. previous_type = hidden_states.type()
  379. if self.fp32_layernorm:
  380. hidden_states = hidden_states.float()
  381. hidden_states = self.LayerNorm(hidden_states)
  382. if self.fp32_layernorm:
  383. hidden_states = hidden_states.type(previous_type)
  384. return hidden_states
  385. class BertLMPredictionHead(nn.Module):
  386. def __init__(self, config, bert_model_embedding_weights):
  387. super(BertLMPredictionHead, self).__init__()
  388. self.transform = BertPredictionHeadTransform(config)
  389. # The output weights are the same as the input embeddings, but there is
  390. # an output-only bias for each token.
  391. self.decoder_weight = bert_model_embedding_weights
  392. self.bias = nn.Parameter(
  393. torch.zeros(bert_model_embedding_weights.size(0)))
  394. self.bias.model_parallel = True
  395. self.fp32_embedding = config.fp32_embedding
  396. self.fp32_layernorm = config.fp32_layernorm
  397. def convert_to_type(tensor):
  398. if self.fp32_embedding:
  399. return tensor.half()
  400. else:
  401. return tensor
  402. self.type_converter = convert_to_type
  403. self.converted = False
  404. def forward(self, hidden_states):
  405. if not self.converted:
  406. self.converted = True
  407. if self.fp32_embedding:
  408. self.transform.half()
  409. if self.fp32_layernorm:
  410. self.transform.LayerNorm.float()
  411. hidden_states = self.transform(self.type_converter(hidden_states))
  412. hidden_states = mpu.copy_to_model_parallel_region(hidden_states)
  413. hidden_states = F.linear(
  414. self.type_converter(hidden_states),
  415. self.type_converter(self.decoder_weight),
  416. self.type_converter(self.bias))
  417. return hidden_states
  418. class BertPreTrainingHeads(nn.Module):
  419. def __init__(self, config, bert_model_embedding_weights):
  420. super(BertPreTrainingHeads, self).__init__()
  421. self.predictions = BertLMPredictionHead(config,
  422. bert_model_embedding_weights)
  423. self.seq_relationship = nn.Linear(config.hidden_size, 3)
  424. def forward(self, sequence_output, pooled_output):
  425. prediction_scores = self.predictions(sequence_output)
  426. for p in self.seq_relationship.parameters():
  427. if p is None:
  428. continue
  429. pooled_output = pooled_output.type_as(p)
  430. seq_relationship_score = self.seq_relationship(pooled_output)
  431. return prediction_scores, seq_relationship_score
  432. class PreTrainedBertModel(nn.Module):
  433. """ An abstract class to handle weights initialization and
  434. a simple interface for downloading and loading pretrained models.
  435. """
  436. def __init__(self, config, *inputs, **kwargs):
  437. super(PreTrainedBertModel, self).__init__()
  438. if not isinstance(config, PlugNLUConfig) and not isinstance(
  439. config, PlugNLGConfig):
  440. raise ValueError(
  441. 'Parameter config in `{}(config)` should be an instance of class `BertConfig`. '
  442. 'To create a model from a Google pretrained model use '
  443. '`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format(
  444. self.__class__.__name__, self.__class__.__name__))
  445. self.config = config
  446. def init_bert_weights(self, module):
  447. """ Initialize the weights.
  448. """
  449. if isinstance(module, (nn.Linear, nn.Embedding)):
  450. # Slightly different from the TF version which uses truncated_normal for initialization
  451. # cf https://github.com/pytorch/pytorch/pull/5617
  452. module.weight.data.normal_(
  453. mean=0.0, std=self.config.initializer_range)
  454. elif isinstance(module, BertLayerNorm):
  455. module.bias.data.zero_()
  456. module.weight.data.fill_(1.0)
  457. if isinstance(module, nn.Linear) and module.bias is not None:
  458. module.bias.data.zero_()
  459. class BertModel(PreTrainedBertModel):
  460. """BERT model ("Bidirectional Embedding Representations from a Transformer").
  461. Params:
  462. config: a BertConfig class instance with the configuration to build a new model
  463. Inputs:
  464. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  465. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  466. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  467. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  468. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  469. a `sentence B` token (see BERT paper for more details).
  470. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  471. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  472. input sequence length in the current batch. It's the mask that we typically use for attention when
  473. a batch has varying length sentences.
  474. `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as
  475. described below. Default: `True`.
  476. Outputs: Tuple of (encoded_layers, pooled_output)
  477. `encoded_layers`: controlled by `output_all_encoded_layers` argument:
  478. - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
  479. of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
  480. encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
  481. - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
  482. to the last attention block of shape [batch_size, sequence_length, hidden_size],
  483. `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
  484. classifier pretrained on top of the hidden state associated to the first character of the
  485. input (`CLF`) to train on the Next-Sentence task (see BERT's paper).
  486. Examples:
  487. >>> # Already been converted into WordPiece token ids
  488. >>> input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  489. >>> input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  490. >>> token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  491. >>> config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  492. >>> num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  493. >>> model = modeling.BertModel(config=config)
  494. >>> all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
  495. """
  496. def __init__(self, config):
  497. super(BertModel, self).__init__(config)
  498. self.embeddings = BertEmbeddings(config)
  499. self.encoder = BertEncoder(config)
  500. self.pooler = BertPooler(config)
  501. self.apply(self.init_bert_weights)
  502. def forward(
  503. self,
  504. input_ids,
  505. token_type_ids=None,
  506. attention_mask=None,
  507. output_all_encoded_layers=True,
  508. checkpoint_activations=False,
  509. detach_index=-1,
  510. ):
  511. if attention_mask is None:
  512. attention_mask = torch.ones_like(input_ids)
  513. if token_type_ids is None:
  514. token_type_ids = torch.zeros_like(input_ids)
  515. # We create a 3D attention mask from a 2D tensor mask.
  516. # Sizes are [batch_size, 1, 1, to_seq_length]
  517. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  518. # this attention mask is more simple than the triangular masking of causal attention
  519. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  520. extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  521. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  522. # masked positions, this operation will create a tensor which is 0.0 for
  523. # positions we want to attend and -10000.0 for masked positions.
  524. # Since we are adding it to the raw scores before the softmax, this is
  525. # effectively the same as removing these entirely.
  526. extended_attention_mask = extended_attention_mask.to(
  527. dtype=next(self.encoder.parameters()).dtype) # fp16 compatibility
  528. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  529. embedding_output = self.embeddings(input_ids, token_type_ids)
  530. encoded_layers = self.encoder(
  531. embedding_output,
  532. extended_attention_mask,
  533. output_all_encoded_layers=output_all_encoded_layers,
  534. checkpoint_activations=checkpoint_activations,
  535. detach_index=detach_index)
  536. sequence_output = encoded_layers[-1]
  537. for p in self.pooler.parameters():
  538. if p is None:
  539. continue
  540. sequence_output = sequence_output.type_as(p)
  541. break
  542. pooled_output = sequence_output[:, 0]
  543. if not output_all_encoded_layers or checkpoint_activations:
  544. encoded_layers = encoded_layers[-1]
  545. return encoded_layers, pooled_output
  546. class DecodeLayer(nn.Module):
  547. def __init__(self, config):
  548. super(DecodeLayer, self).__init__()
  549. init_method = normal_init_method(
  550. mean=0.0, std=config.initializer_range)
  551. output_layer_init_method = scaled_init_method(
  552. mean=0.0,
  553. std=config.initializer_range,
  554. num_layers=config.num_hidden_layers)
  555. self.attention = mpu.GPT2ParallelSelfAttention(
  556. hidden_size=config.hidden_size,
  557. num_attention_heads=config.num_attention_heads,
  558. attention_dropout_prob=config.attention_probs_dropout_prob,
  559. output_dropout_prob=config.hidden_dropout_prob,
  560. init_method=init_method,
  561. output_layer_init_method=output_layer_init_method,
  562. )
  563. self.cross_attention = mpu.PalmParallelCrossAttention(
  564. hidden_size=config.hidden_size,
  565. num_attention_heads=config.num_attention_heads,
  566. attention_dropout_prob=config.attention_probs_dropout_prob,
  567. output_dropout_prob=config.hidden_dropout_prob,
  568. init_method=init_method,
  569. attn_separate=False,
  570. output_layer_init_method=output_layer_init_method,
  571. )
  572. self.input_layernorm = BertLayerNorm(
  573. config.hidden_size, eps=config.layernorm_epsilon)
  574. self.post_attention_layernorm = BertLayerNorm(
  575. config.hidden_size, eps=config.layernorm_epsilon)
  576. self.post_cross_attention_layernorm = BertLayerNorm(
  577. config.hidden_size, eps=config.layernorm_epsilon)
  578. self.intermediate = mpu.ColumnParallelLinear(
  579. config.hidden_size,
  580. config.intermediate_size,
  581. gather_output=False,
  582. init_method=init_method,
  583. )
  584. self.intermediate_act_fn = ACT2FN[config.hidden_act] \
  585. if isinstance(config.hidden_act, str) else config.hidden_act
  586. self.output = mpu.RowParallelLinear(
  587. config.intermediate_size,
  588. config.hidden_size,
  589. input_is_parallel=True,
  590. init_method=output_layer_init_method,
  591. )
  592. self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
  593. self.fp32_layernorm = config.fp32_layernorm
  594. def convert_to_type(tensor):
  595. if self.fp32_layernorm:
  596. return tensor.float()
  597. else:
  598. return tensor
  599. self.type_converter = convert_to_type
  600. # def forward(self, hidden_states, enc_attn_mask, dec_attn_mask):
  601. def forward(self,
  602. hidden_states,
  603. enc_hidden_states,
  604. enc_attn_mask,
  605. dec_attn_mask,
  606. is_infer=False):
  607. residual = hidden_states
  608. previous_type = hidden_states.type()
  609. hidden_states = self.input_layernorm(
  610. self.type_converter(hidden_states))
  611. if self.fp32_layernorm:
  612. hidden_states = hidden_states.type(previous_type)
  613. hidden_states = self.attention(
  614. hidden_states, dec_attn_mask, is_infer=is_infer)
  615. hidden_states = residual + hidden_states
  616. residual = hidden_states
  617. hidden_states = self.post_attention_layernorm(
  618. self.type_converter(hidden_states))
  619. if self.fp32_layernorm:
  620. hidden_states = hidden_states.type(previous_type)
  621. hidden_states = self.cross_attention(hidden_states, enc_hidden_states,
  622. enc_attn_mask)
  623. hidden_states = residual + hidden_states
  624. residual = hidden_states
  625. hidden_states = self.post_cross_attention_layernorm(
  626. self.type_converter(hidden_states))
  627. if self.fp32_layernorm:
  628. hidden_states = hidden_states.type(previous_type)
  629. hidden_states = self.intermediate(hidden_states)
  630. hidden_states = self.intermediate_act_fn(hidden_states)
  631. hidden_states = self.output(hidden_states)
  632. hidden_states = self.dropout(hidden_states)
  633. hidden_states = residual + hidden_states
  634. return hidden_states
class BertDecoder(nn.Module):
    """Stack of ``config.dec_hidden_layers`` DecodeLayer modules plus a final LayerNorm."""

    def __init__(self, config):
        super(BertDecoder, self).__init__()
        self.layer = nn.ModuleList(
            [DecodeLayer(config) for _ in range(config.dec_hidden_layers)])
        self.final_layernorm = BertLayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)
        self.fp32_layernorm = config.fp32_layernorm

    def forward(self,
                hidden_states,
                enc_hidden_states,
                enc_attn_mask,
                dec_attn_mask,
                checkpoint_activations=False,
                output_all_encoded_layers=False,
                is_infer=False):
        """Run the decoder stack and return ``[final_hidden_states]``.

        Args:
            hidden_states: input to the first decoder layer.
            enc_hidden_states: encoder output passed to every layer.
            enc_attn_mask: encoder-side mask passed to every layer.
            dec_attn_mask: decoder-side mask passed to every layer.
            checkpoint_activations: run each layer through ``mpu.checkpoint``.
            output_all_encoded_layers: accepted but not read in this method.
            is_infer: forwarded to every layer.

        Returns:
            A one-element list holding the layer-normalised final hidden state.
        """

        def custom(start, end):
            # Build a callable that runs layers [start, end) so that it can be
            # handed to mpu.checkpoint. dec_attn_mask and is_infer are taken
            # from the enclosing scope rather than from the checkpoint inputs.

            def custom_forward(*inputs):
                layers = self.layer[start:end]
                x_ = inputs[0]
                for layer in layers:
                    x_ = layer(
                        x_,
                        inputs[1],
                        inputs[2],
                        dec_attn_mask * 1,
                        is_infer=is_infer)
                return x_

            return custom_forward

        # Keep a handle on the encoder tensor's data so that it can be
        # re-attached after each checkpointed layer.
        # NOTE(review): presumably this guards against mpu.checkpoint altering
        # the tensor's storage - confirm against the mpu implementation.
        pre_enc_hidden = enc_hidden_states.data
        if checkpoint_activations:
            layer_idx = 0
            num_layers = len(self.layer)
            # One layer per checkpointed chunk.
            chunk_length = 1
            while layer_idx < num_layers:
                hidden_states = mpu.checkpoint(
                    custom(layer_idx, layer_idx + chunk_length), hidden_states,
                    enc_hidden_states, enc_attn_mask * 1)
                enc_hidden_states.data = pre_enc_hidden
                layer_idx += chunk_length
        else:
            for i, layer_module in enumerate(self.layer):
                hidden_states = layer_module(
                    hidden_states,
                    enc_hidden_states,
                    enc_attn_mask,
                    dec_attn_mask,
                    is_infer=is_infer)
        # Final LayerNorm, run in fp32 when configured and then cast back.
        previous_type = hidden_states.type()
        if self.fp32_layernorm:
            hidden_states = hidden_states.float()
        hidden_states = self.final_layernorm(hidden_states)
        if self.fp32_layernorm:
            hidden_states = hidden_states.type(previous_type)
        return [hidden_states]
  690. class DecodeModel(PreTrainedBertModel):
  691. def __init__(self, config):
  692. super(DecodeModel, self).__init__(config)
  693. self.decoder = BertDecoder(config)
  694. self.apply(self.init_bert_weights)
  695. def forward(self,
  696. embeddings,
  697. sequence_output,
  698. decode_input_ids,
  699. position_ids=None,
  700. enc_attn_mask=None,
  701. dec_attn_mask=None,
  702. checkpoint_activations=False,
  703. is_infer=False):
  704. extended_attention_mask = enc_attn_mask.unsqueeze(1).unsqueeze(2)
  705. extended_attention_mask = extended_attention_mask.to(
  706. dtype=next(self.decoder.parameters()).dtype) # fp16 compatibility
  707. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  708. embedding_output = embeddings(decode_input_ids)
  709. sequence_output = self.decoder(
  710. embedding_output,
  711. sequence_output,
  712. extended_attention_mask,
  713. dec_attn_mask,
  714. checkpoint_activations=False,
  715. is_infer=is_infer)
  716. return sequence_output[-1]
  717. class PalmForPreTraining(PreTrainedBertModel):
  718. def __init__(self, config):
  719. super(PalmForPreTraining, self).__init__(config)
  720. self.bert = BertModel(config)
  721. self.cls = BertPreTrainingHeads(
  722. config, self.bert.embeddings.word_embeddings.weight)
  723. self.decoder = DecodeModel(config)
  724. self.apply(self.init_bert_weights)
  725. def forward(self,
  726. input_ids,
  727. token_type_ids=None,
  728. attention_mask=None,
  729. decode_input_ids=None,
  730. position_ids=None,
  731. decode_attention_mask=None,
  732. lm_labels=None,
  733. checkpoint_activations=False,
  734. is_infer=False,
  735. sequence_output=None,
  736. parallel_output=True):
  737. if sequence_output is None:
  738. sequence_output, pooled_output = self.bert(
  739. input_ids,
  740. token_type_ids,
  741. attention_mask,
  742. output_all_encoded_layers=False,
  743. checkpoint_activations=checkpoint_activations)
  744. prediction_scores, seq_relationship_score = self.cls(
  745. sequence_output, pooled_output)
  746. else:
  747. prediction_scores = None
  748. sequence_output = sequence_output.to(
  749. dtype=next(self.decoder.parameters()).dtype)
  750. if attention_mask is None:
  751. attention_mask = torch.ones_like(input_ids)
  752. decode_output = self.decoder(
  753. self.bert.embeddings,
  754. sequence_output,
  755. decode_input_ids,
  756. position_ids,
  757. attention_mask,
  758. decode_attention_mask,
  759. checkpoint_activations=checkpoint_activations,
  760. is_infer=is_infer)
  761. transformer_output_parallel = mpu.copy_to_model_parallel_region(
  762. decode_output)
  763. logits_parallel = F.linear(transformer_output_parallel,
  764. self.bert.embeddings.word_embeddings.weight)
  765. if parallel_output:
  766. return prediction_scores, logits_parallel
  767. if is_infer:
  768. return prediction_scores, mpu.gather_from_model_parallel_region(
  769. logits_parallel), sequence_output
  770. return prediction_scores, mpu.gather_from_model_parallel_region(
  771. logits_parallel)
  772. class PlugModel(torch.nn.Module):
  773. """
  774. The bare Plug Model transformer outputting raw hidden-states without any specific head on top.
  775. This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  776. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  777. and behavior.
  778. Parameters:
  779. config ([`PlugNLGConfig`]): Model configuration class with all the parameters of the model.
  780. Initializing with a config file does not load the weights associated with the model, only the
  781. configuration. Check out the [`~DistributedPlug.initialize_model`] method to load the model weights.
  782. Examples:
  783. >>> # The PLUG model has 27B parameters and usually need to run on multiple GPUs. The example given
  784. >>> # here only initializes a slice of the model on a single GPU.
  785. >>> # Check out the [`~DistributedPipeline.__init__`] method to initialize entire PLUG model.
  786. >>> from modelscope.models.nlp.plug import PlugNLGConfig, PlugModel
  787. >>> # Initializing a Plug configuration
  788. >>> configuration = PlugNLGConfig()
  789. >>> # Initializing a model from the configuration
  790. >>> model = PlugModel(configuration)
  791. """
  792. def __init__(self, config):
  793. super(PlugModel, self).__init__()
  794. self.config = config
  795. self.model = PalmForPreTraining(self.config)
  796. def forward(self,
  797. input_tokens,
  798. token_type_ids=None,
  799. attention_mask=None,
  800. target_tokens=None,
  801. position_ids=None,
  802. decode_attention_mask=None,
  803. checkpoint_activations=False,
  804. is_infer=False,
  805. sequence_output=None,
  806. parallel_output=True):
  807. """
  808. Parameters:
  809. input_tokens (`torch.LongTensor` of shape `(batch_size, input_tokens_length)`):
  810. `input_tokens_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
  811. Indices can be obtained using transformers [`BertTokenizer`]. See
  812. [`TextGenerationPreprocessor.__call__`] for details.
  813. token_type_ids (`torch.LongTensor` of shape `(batch_size, input_tokens_length)`, *optional*, defaults to
  814. None):
  815. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  816. 1]`:
  817. - 0 corresponds to a *sentence A* token,
  818. - 1 corresponds to a *sentence B* token.
  819. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None):
  820. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  821. - 1 for tokens that are **not masked**,
  822. - 0 for tokens that are **masked**.
  823. target_tokens (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None):
  824. Target token ids(labels) for language modeling. Note that the labels **are shifted** inside the model,
  825. i.e. you can set `target_tokens = input_tokens` Indices are selected in
  826. `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only
  827. computed for labels in `[0, ..., config.vocab_size]`
  828. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None):
  829. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
  830. `[0, config.max_position_embeddings - 1]`.
  831. decode_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults
  832. to None):
  833. Mask to avoid performing attention on padding token indices of target tokens. Mask values selected in
  834. `[0, 1]`:
  835. - 1 for tokens that are **not masked**,
  836. - 0 for tokens that are **masked**.
  837. checkpoint_activations (`boolean`, *optional*, defaults to `False`):
  838. Whether gradient checkpointing is activated for this model or not.
  839. is_infer (`boolean`, *optional*, defaults to `False`):
  840. Whether or not to perform single inference.
  841. sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*,
  842. defaults to None):
  843. Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the
  844. model. A single forward() call can produce one single token. To generate the current token, the
  845. sequence_output generated by the `forward()` of the previous token is required.
  846. parallel_output (`boolean`, *optional*, defaults to `True`):
  847. To parallel return output, or gather it before return.
  848. """
  849. return self.model(
  850. input_tokens,
  851. token_type_ids,
  852. attention_mask,
  853. target_tokens,
  854. position_ids,
  855. decode_attention_mask,
  856. checkpoint_activations=checkpoint_activations,
  857. is_infer=is_infer,
  858. sequence_output=sequence_output,
  859. parallel_output=parallel_output)
  860. def state_dict(self, destination=None, prefix='', keep_vars=False):
  861. return self.model.state_dict(
  862. destination=destination, prefix=prefix, keep_vars=keep_vars)
  863. def load_state_dict(self, state_dict, strict=True):
  864. return self.model.load_state_dict(state_dict, strict=strict)