structbert.py 39 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932
  1. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team and Alibaba inc.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch BERT model."""
  15. from __future__ import absolute_import, division, print_function
  16. import copy
  17. import math
  18. import json
  19. import numpy as np
  20. import six
  21. import torch
  22. import torch.nn as nn
  23. import torch.nn.functional as F
  24. import torch.utils.checkpoint
  25. from torch.nn import CrossEntropyLoss
  26. def gelu(x):
  27. """Implementation of the gelu activation function.
  28. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
  29. 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
  30. """
  31. return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
  32. class BertConfig(object):
  33. """Configuration class to store the configuration of a `BertModel`.
  34. """
  35. def __init__(self,
  36. vocab_size,
  37. hidden_size=768,
  38. emb_size=-1,
  39. num_hidden_layers=12,
  40. transformer_type='original',
  41. transition_function='linear',
  42. weighted_transformer=0,
  43. num_rolled_layers=3,
  44. num_attention_heads=12,
  45. intermediate_size=3072,
  46. hidden_act='gelu',
  47. hidden_dropout_prob=0.1,
  48. attention_probs_dropout_prob=0.1,
  49. max_position_embeddings=512,
  50. type_vocab_size=16,
  51. initializer_range=0.02,
  52. attention_type='self',
  53. rezero=False,
  54. pre_ln=False,
  55. squeeze_excitation=False,
  56. transfer_matrix=False,
  57. dim_dropout=False,
  58. roberta_style=False,
  59. set_mask_zero=False,
  60. init_scale=False,
  61. safer_fp16=False,
  62. grad_checkpoint=False):
  63. """Constructs BertConfig.
  64. Args:
  65. vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
  66. hidden_size: Size of the encoder layers and the pooler layer.
  67. num_hidden_layers: Number of hidden layers in the Transformer encoder.
  68. num_attention_heads: Number of attention heads for each attention layer in
  69. the Transformer encoder.
  70. intermediate_size: The size of the "intermediate" (i.e., feed-forward)
  71. layer in the Transformer encoder.
  72. hidden_act: The non-linear activation function (function or string) in the
  73. encoder and pooler.
  74. hidden_dropout_prob: The dropout probability for all fully connected
  75. layers in the embeddings, encoder, and pooler.
  76. attention_probs_dropout_prob: The dropout ratio for the attention
  77. probabilities.
  78. max_position_embeddings: The maximum sequence length that this model might
  79. ever be used with. Typically set this to something large just in case
  80. (e.g., 512 or 1024 or 2048).
  81. type_vocab_size: The vocabulary size of the `token_type_ids` passed into
  82. `BertModel`.
  83. initializer_range: The stdev of the truncated_normal_initializer for
  84. initializing all weight matrices.
  85. """
  86. self.vocab_size = vocab_size
  87. self.hidden_size = hidden_size
  88. self.emb_size = emb_size
  89. self.num_hidden_layers = num_hidden_layers
  90. self.transformer_type = transformer_type
  91. self.transition_function = transition_function
  92. self.weighted_transformer = weighted_transformer
  93. self.num_rolled_layers = num_rolled_layers
  94. self.num_attention_heads = num_attention_heads
  95. self.hidden_act = hidden_act
  96. self.intermediate_size = intermediate_size
  97. self.hidden_dropout_prob = hidden_dropout_prob
  98. self.attention_probs_dropout_prob = attention_probs_dropout_prob
  99. self.max_position_embeddings = max_position_embeddings
  100. self.type_vocab_size = type_vocab_size
  101. self.initializer_range = initializer_range
  102. self.attention_type = attention_type
  103. self.rezero = rezero
  104. self.pre_ln = pre_ln
  105. self.squeeze_excitation = squeeze_excitation
  106. self.transfer_matrix = transfer_matrix
  107. self.dim_dropout = dim_dropout
  108. self.set_mask_zero = set_mask_zero
  109. self.roberta_style = roberta_style
  110. self.init_scale = init_scale
  111. self.safer_fp16 = safer_fp16
  112. self.grad_checkpoint = grad_checkpoint
  113. @classmethod
  114. def from_dict(cls, json_object):
  115. """Constructs a `BertConfig` from a Python dictionary of parameters."""
  116. config = BertConfig(vocab_size=None)
  117. for (key, value) in six.iteritems(json_object):
  118. config.__dict__[key] = value
  119. return config
  120. @classmethod
  121. def from_json_file(cls, json_file):
  122. """Constructs a `BertConfig` from a json file of parameters."""
  123. with open(json_file, 'r', encoding='utf-8') as reader:
  124. text = reader.read()
  125. return cls.from_dict(json.loads(text))
  126. def to_dict(self):
  127. """Serializes this instance to a Python dictionary."""
  128. output = copy.deepcopy(self.__dict__)
  129. return output
  130. def to_json_string(self):
  131. """Serializes this instance to a JSON string."""
  132. return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n'
  133. class BERTLayerNorm(nn.Module):
  134. def __init__(self, config, variance_epsilon=1e-12, special_size=None):
  135. """Construct a layernorm module in the TF style (epsilon inside the square root).
  136. """
  137. super(BERTLayerNorm, self).__init__()
  138. self.config = config
  139. hidden_size = special_size if special_size is not None else config.hidden_size
  140. self.gamma = nn.Parameter(torch.ones(hidden_size))
  141. self.beta = nn.Parameter(torch.zeros(hidden_size))
  142. self.variance_epsilon = variance_epsilon if not config.roberta_style else 1e-5
  143. def forward(self, x):
  144. previous_type = x.type()
  145. if self.config.safer_fp16:
  146. x = x.float()
  147. u = x.mean(-1, keepdim=True)
  148. s = (x - u).pow(2).mean(-1, keepdim=True)
  149. x = (x - u) / torch.sqrt(s + self.variance_epsilon)
  150. if self.config.safer_fp16:
  151. return (self.gamma * x + self.beta).type(previous_type)
  152. else:
  153. return self.gamma * x + self.beta
  154. class BERTEmbeddings(nn.Module):
  155. def __init__(self, config):
  156. super(BERTEmbeddings, self).__init__()
  157. """Construct the embedding module from word, position and token_type embeddings.
  158. """
  159. hidden_size = config.hidden_size if config.emb_size < 0 else config.emb_size
  160. self.word_embeddings = nn.Embedding(
  161. config.vocab_size,
  162. hidden_size,
  163. padding_idx=1 if config.roberta_style else None)
  164. self.position_embeddings = nn.Embedding(
  165. config.max_position_embeddings,
  166. hidden_size,
  167. padding_idx=1 if config.roberta_style else None)
  168. self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
  169. hidden_size)
  170. self.config = config
  171. self.proj = None if config.emb_size < 0 else nn.Linear(
  172. config.emb_size, config.hidden_size)
  173. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  174. # any TensorFlow checkpoint file
  175. self.LayerNorm = BERTLayerNorm(config, special_size=hidden_size)
  176. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  177. def forward(self, input_ids, token_type_ids=None, adv_embedding=None):
  178. seq_length = input_ids.size(1)
  179. if not self.config.roberta_style:
  180. position_ids = torch.arange(
  181. seq_length, dtype=torch.long, device=input_ids.device)
  182. position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
  183. else:
  184. mask = input_ids.ne(1).int()
  185. position_ids = (torch.cumsum(mask, dim=1).type_as(mask)
  186. * mask).long() + 1
  187. if token_type_ids is None:
  188. token_type_ids = torch.zeros_like(input_ids)
  189. words_embeddings = self.word_embeddings(
  190. input_ids) if adv_embedding is None else adv_embedding
  191. if self.config.set_mask_zero:
  192. words_embeddings[input_ids == 103] = 0.
  193. position_embeddings = self.position_embeddings(position_ids)
  194. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  195. if not self.config.roberta_style:
  196. embeddings = words_embeddings + position_embeddings + token_type_embeddings
  197. else:
  198. embeddings = words_embeddings + position_embeddings
  199. embeddings = self.LayerNorm(embeddings)
  200. embeddings = self.dropout(embeddings)
  201. if self.proj is not None:
  202. embeddings = self.proj(embeddings)
  203. embeddings = self.dropout(embeddings)
  204. else:
  205. return embeddings, words_embeddings
  206. class BERTFactorizedAttention(nn.Module):
  207. def __init__(self, config):
  208. super(BERTFactorizedAttention, self).__init__()
  209. if config.hidden_size % config.num_attention_heads != 0:
  210. raise ValueError(
  211. 'The hidden size (%d) is not a multiple of the number of attention '
  212. 'heads (%d)' %
  213. (config.hidden_size, config.num_attention_heads))
  214. self.num_attention_heads = config.num_attention_heads
  215. self.attention_head_size = int(config.hidden_size
  216. / config.num_attention_heads)
  217. self.all_head_size = self.num_attention_heads * self.attention_head_size
  218. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  219. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  220. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  221. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  222. def transpose_for_scores(self, x, *size):
  223. new_x_shape = x.size()[:-1] + (self.num_attention_heads,
  224. self.attention_head_size)
  225. x = x.view(*new_x_shape)
  226. return x.permute(size)
  227. def forward(self, hidden_states, attention_mask):
  228. mixed_query_layer = self.query(hidden_states)
  229. mixed_key_layer = self.key(hidden_states)
  230. mixed_value_layer = self.value(hidden_states)
  231. query_layer = self.transpose_for_scores(mixed_query_layer, 0, 2, 3, 1)
  232. key_layer = self.transpose_for_scores(mixed_key_layer, 0, 2, 1, 3)
  233. value_layer = self.transpose_for_scores(mixed_value_layer, 0, 2, 1, 3)
  234. s_attention_scores = query_layer + attention_mask
  235. s_attention_probs = nn.Softmax(dim=-1)(s_attention_scores)
  236. s_attention_probs = self.dropout(s_attention_probs)
  237. c_attention_probs = nn.Softmax(dim=-1)(key_layer)
  238. s_context_layer = torch.matmul(s_attention_probs, value_layer)
  239. context_layer = torch.matmul(c_attention_probs, s_context_layer)
  240. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  241. new_context_layer_shape = context_layer.size()[:-2] + (
  242. self.all_head_size, )
  243. context_layer = context_layer.view(*new_context_layer_shape)
  244. return context_layer
  245. def dim_dropout(x, p=0, dim=-1, training=False):
  246. if not training or p == 0:
  247. return x
  248. a = (1 - p)
  249. b = (x.data.new(x.size()).zero_() + 1)
  250. dropout_mask = torch.bernoulli(a * b)
  251. return dropout_mask * (dropout_mask.size(dim) / torch.sum(
  252. dropout_mask, dim=dim, keepdim=True)) * x
  253. class BERTSelfAttention(nn.Module):
  254. def __init__(self, config):
  255. super(BERTSelfAttention, self).__init__()
  256. if config.hidden_size % config.num_attention_heads != 0:
  257. raise ValueError(
  258. 'The hidden size (%d) is not a multiple of the number of attention '
  259. 'heads (%d)' %
  260. (config.hidden_size, config.num_attention_heads))
  261. self.num_attention_heads = config.num_attention_heads
  262. self.attention_head_size = int(config.hidden_size
  263. / config.num_attention_heads)
  264. self.all_head_size = self.num_attention_heads * self.attention_head_size
  265. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  266. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  267. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  268. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  269. self.config = config
  270. if config.pre_ln:
  271. self.LayerNorm = BERTLayerNorm(config)
  272. def transpose_for_scores(self, x):
  273. new_x_shape = x.size()[:-1] + (self.num_attention_heads,
  274. self.attention_head_size)
  275. x = x.view(*new_x_shape)
  276. return x.permute(0, 2, 1, 3)
  277. def forward(self, hidden_states, attention_mask, head_mask=None):
  278. if self.config.pre_ln:
  279. hidden_states = self.LayerNorm(hidden_states)
  280. mixed_query_layer = self.query(hidden_states)
  281. mixed_key_layer = self.key(hidden_states)
  282. mixed_value_layer = self.value(hidden_states)
  283. query_layer = self.transpose_for_scores(mixed_query_layer)
  284. key_layer = self.transpose_for_scores(mixed_key_layer)
  285. value_layer = self.transpose_for_scores(mixed_value_layer)
  286. # Take the dot product between "query" and "key" to get the raw attention scores.
  287. attention_scores = torch.matmul(query_layer,
  288. key_layer.transpose(-1, -2))
  289. attention_scores = attention_scores / math.sqrt(
  290. self.attention_head_size)
  291. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  292. if head_mask is not None and not self.training:
  293. for i, mask in enumerate(head_mask):
  294. if head_mask[i] == 1:
  295. attention_scores[:, i, :, :] = 0.
  296. attention_scores = attention_scores + attention_mask
  297. # Normalize the attention scores to probabilities.
  298. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  299. # This is actually dropping out entire tokens to attend to, which might
  300. # seem a bit unusual, but is taken from the original Transformer paper.
  301. if not self.config.dim_dropout:
  302. attention_probs = self.dropout(attention_probs)
  303. else:
  304. attention_probs = dim_dropout(
  305. attention_probs,
  306. p=self.config.attention_probs_dropout_prob,
  307. dim=-1,
  308. training=self.training)
  309. context_layer = torch.matmul(attention_probs, value_layer)
  310. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  311. new_context_layer_shape = context_layer.size()[:-2] + (
  312. self.all_head_size, )
  313. context_layer = context_layer.view(*new_context_layer_shape)
  314. return context_layer
  315. class BERTSelfOutput(nn.Module):
  316. def __init__(self, config):
  317. super(BERTSelfOutput, self).__init__()
  318. self.config = config
  319. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  320. if not config.pre_ln and not config.rezero:
  321. self.LayerNorm = BERTLayerNorm(config)
  322. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  323. if config.rezero:
  324. self.res_factor = nn.Parameter(
  325. torch.Tensor(1).fill_(0.99).to(
  326. dtype=next(self.parameters()).dtype))
  327. self.factor = nn.Parameter(
  328. torch.ones(1).to(dtype=next(self.parameters()).dtype))
  329. def forward(self, hidden_states, input_tensor):
  330. hidden_states = self.dense(hidden_states)
  331. hidden_states = self.dropout(hidden_states)
  332. if not self.config.rezero and not self.config.pre_ln:
  333. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  334. elif self.config.rezero:
  335. hidden_states = hidden_states + self.factor * input_tensor
  336. else:
  337. pass
  338. return hidden_states
  339. class BERTAttention(nn.Module):
  340. def __init__(self, config):
  341. super(BERTAttention, self).__init__()
  342. if config.attention_type.lower() == 'self':
  343. self.self = BERTSelfAttention(config)
  344. elif config.attention_type.lower() == 'factorized':
  345. self.self = BERTFactorizedAttention(config)
  346. else:
  347. raise ValueError(
  348. 'Attention type must in [self, factorized], but got {}'.format(
  349. config.attention_type))
  350. self.output = BERTSelfOutput(config)
  351. def forward(self, input_tensor, attention_mask, head_mask=None):
  352. self_output = self.self(input_tensor, attention_mask, head_mask)
  353. attention_output = self.output(self_output, input_tensor)
  354. return attention_output
  355. class DepthwiseSeparableConv1d(nn.Module):
  356. def __init__(self,
  357. in_channels,
  358. out_channels,
  359. kernel_size=1,
  360. stride=1,
  361. padding=0,
  362. dilation=1,
  363. bias=False):
  364. super(DepthwiseSeparableConv1d, self).__init__()
  365. padding = (kernel_size - 1) // 2
  366. self.depthwise = nn.Conv1d(
  367. in_channels,
  368. in_channels,
  369. kernel_size,
  370. stride,
  371. padding,
  372. dilation,
  373. groups=in_channels,
  374. bias=bias)
  375. self.pointwise = nn.Conv1d(
  376. in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
  377. def forward(self, x):
  378. x = self.depthwise(x)
  379. x = self.pointwise(x)
  380. return x
  381. class BERTIntermediate(nn.Module):
  382. def __init__(self, config):
  383. super(BERTIntermediate, self).__init__()
  384. self.config = config
  385. if self.config.pre_ln:
  386. self.LayerNorm = BERTLayerNorm(config)
  387. self.intermediate_act_fn = gelu
  388. if config.transition_function.lower() == 'linear':
  389. self.dense = nn.Linear(config.hidden_size,
  390. config.intermediate_size)
  391. elif config.transition_function.lower() == 'cnn':
  392. self.cnn = DepthwiseSeparableConv1d(
  393. config.hidden_size, 4 * config.hidden_size, kernel_size=7)
  394. elif config.config.hidden_size.lower() == 'rnn':
  395. raise NotImplementedError(
  396. 'rnn transition function is not implemented yet')
  397. else:
  398. raise ValueError('Only support linear/cnn/rnn')
  399. def forward(self, hidden_states):
  400. if self.config.pre_ln:
  401. hidden_states = self.LayerNorm(hidden_states)
  402. if self.config.transition_function.lower() == 'linear':
  403. hidden_states = self.dense(hidden_states)
  404. elif self.config.transition_function.lower() == 'cnn':
  405. hidden_states = self.cnn(hidden_states.transpose(-1,
  406. -2)).transpose(
  407. -1, -2)
  408. else:
  409. pass
  410. hidden_states = self.intermediate_act_fn(hidden_states)
  411. return hidden_states
  412. class SqueezeExcitationBlock(nn.Module):
  413. def __init__(self, config):
  414. super(SqueezeExcitationBlock, self).__init__()
  415. self.down_sampling = nn.Linear(config.hidden_size,
  416. config.hidden_size // 4)
  417. self.up_sampling = nn.Linear(config.hidden_size // 4,
  418. config.hidden_size)
  419. def forward(self, hidden_states):
  420. squeeze = torch.mean(hidden_states, 1, keepdim=True)
  421. excitation = torch.sigmoid(
  422. self.up_sampling(gelu(self.down_sampling(squeeze))))
  423. return hidden_states * excitation
  424. class BERTOutput(nn.Module):
  425. def __init__(self, config):
  426. super(BERTOutput, self).__init__()
  427. self.config = config
  428. if config.transition_function.lower() == 'linear':
  429. self.dense = nn.Linear(config.intermediate_size,
  430. config.hidden_size)
  431. elif config.transition_function.lower() == 'cnn':
  432. self.cnn = DepthwiseSeparableConv1d(
  433. 4 * config.hidden_size, config.hidden_size, kernel_size=7)
  434. elif config.config.hidden_size.lower() == 'rnn':
  435. raise NotImplementedError(
  436. 'rnn transition function is not implemented yet')
  437. else:
  438. raise ValueError('Only support linear/cnn/rnn')
  439. if not config.pre_ln and not config.rezero:
  440. self.LayerNorm = BERTLayerNorm(config)
  441. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  442. if config.squeeze_excitation:
  443. self.SEblock = SqueezeExcitationBlock(config)
  444. if config.rezero:
  445. self.res_factor = nn.Parameter(
  446. torch.Tensor(1).fill_(0.99).to(
  447. dtype=next(self.parameters()).dtype))
  448. self.factor = nn.Parameter(
  449. torch.ones(1).to(dtype=next(self.parameters()).dtype))
  450. def forward(self, hidden_states, input_tensor):
  451. if self.config.transition_function.lower() == 'linear':
  452. hidden_states = self.dense(hidden_states)
  453. elif self.config.transition_function.lower() == 'cnn':
  454. hidden_states = self.cnn(hidden_states.transpose(-1,
  455. -2)).transpose(
  456. -1, -2)
  457. else:
  458. pass
  459. hidden_states = self.dropout(hidden_states)
  460. if self.config.squeeze_excitation:
  461. hidden_states = self.SEblock(hidden_states)
  462. if not self.config.rezero and not self.config.pre_ln:
  463. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  464. elif self.config.rezero:
  465. hidden_states = hidden_states + self.factor * input_tensor
  466. else:
  467. pass
  468. return hidden_states
  469. class BERTLayer(nn.Module):
  470. def __init__(self, config):
  471. super(BERTLayer, self).__init__()
  472. self.attention = BERTAttention(config)
  473. self.intermediate = BERTIntermediate(config)
  474. self.output = BERTOutput(config)
  475. def forward(self, hidden_states, attention_mask, head_mask=None):
  476. attention_output = self.attention(hidden_states, attention_mask,
  477. head_mask)
  478. intermediate_output = self.intermediate(attention_output)
  479. layer_output = self.output(intermediate_output, attention_output)
  480. return attention_output, layer_output
  481. class BERTWeightedLayer(nn.Module):
  482. def __init__(self, config):
  483. super(BERTWeightedLayer, self).__init__()
  484. self.config = config
  485. self.self = BERTSelfAttention(config)
  486. self.attention_head_size = self.self.attention_head_size
  487. self.w_o = nn.ModuleList([
  488. nn.Linear(self.attention_head_size, config.hidden_size)
  489. for _ in range(config.num_attention_heads)
  490. ])
  491. self.w_kp = torch.rand(config.num_attention_heads)
  492. self.w_kp = nn.Parameter(self.w_kp / self.w_kp.sum())
  493. self.w_a = torch.rand(config.num_attention_heads)
  494. self.w_a = nn.Parameter(self.w_a / self.w_a.sum())
  495. self.intermediate = BERTIntermediate(config)
  496. self.output = nn.Linear(config.intermediate_size, config.hidden_size)
  497. self.LayerNorm = BERTLayerNorm(config)
  498. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  499. def forward(self, hidden_states, attention_mask):
  500. self_output = self.self(hidden_states, attention_mask)
  501. self_outputs = self_output.split(self.self.attention_head_size, dim=-1)
  502. self_outputs = [
  503. self.w_o[i](self_outputs[i]) for i in range(len(self_outputs))
  504. ]
  505. self_outputs = [
  506. self.dropout(self_outputs[i]) for i in range(len(self_outputs))
  507. ]
  508. self_outputs = [
  509. kappa * output for kappa, output in zip(self.w_kp, self_outputs)
  510. ]
  511. self_outputs = [
  512. self.intermediate(self_outputs[i])
  513. for i in range(len(self_outputs))
  514. ]
  515. self_outputs = [
  516. self.output(self_outputs[i]) for i in range(len(self_outputs))
  517. ]
  518. self_outputs = [
  519. self.dropout(self_outputs[i]) for i in range(len(self_outputs))
  520. ]
  521. self_outputs = [
  522. alpha * output for alpha, output in zip(self.w_a, self_outputs)
  523. ]
  524. output = sum(self_outputs)
  525. return self.LayerNorm(hidden_states + output)
  526. class BERTEncoder(nn.Module):
  527. def __init__(self, config):
  528. super(BERTEncoder, self).__init__()
  529. self.layer = nn.ModuleList()
  530. for _ in range(config.num_hidden_layers):
  531. if config.weighted_transformer:
  532. self.layer.append(BERTWeightedLayer(config))
  533. else:
  534. self.layer.append(BERTLayer(config))
  535. if config.rezero:
  536. for index, layer in enumerate(self.layer):
  537. layer.output.res_factor = nn.Parameter(
  538. torch.Tensor(1).fill_(1.).to(
  539. dtype=next(self.parameters()).dtype))
  540. layer.output.factor = nn.Parameter(
  541. torch.Tensor(1).fill_(1).to(
  542. dtype=next(self.parameters()).dtype))
  543. layer.attention.output.res_factor = layer.output.res_factor
  544. layer.attention.output.factor = layer.output.factor
  545. self.config = config
  546. def forward(self,
  547. hidden_states,
  548. attention_mask,
  549. epoch_id=-1,
  550. head_masks=None):
  551. all_encoder_layers = [hidden_states]
  552. if epoch_id != -1:
  553. detach_index = int(len(self.layer) / 3) * (2 - epoch_id) - 1
  554. else:
  555. detach_index = -1
  556. for index, layer_module in enumerate(self.layer):
  557. if head_masks is None:
  558. if not self.config.grad_checkpoint:
  559. self_out, hidden_states = layer_module(
  560. hidden_states, attention_mask, None)
  561. else:
  562. self_out, hidden_states = torch.utils.checkpoint.checkpoint(
  563. layer_module, hidden_states, attention_mask, None)
  564. else:
  565. self_out, hidden_states = layer_module(hidden_states,
  566. attention_mask,
  567. head_masks[index])
  568. if detach_index == index:
  569. hidden_states.detach_()
  570. all_encoder_layers.append(self_out)
  571. all_encoder_layers.append(hidden_states)
  572. return all_encoder_layers
  573. class BERTEncoderRolled(nn.Module):
  574. def __init__(self, config):
  575. super(BERTEncoderRolled, self).__init__()
  576. layer = BERTLayer(config)
  577. self.config = config
  578. self.layer = nn.ModuleList(
  579. [copy.deepcopy(layer) for _ in range(config.num_rolled_layers)])
  580. def forward(self,
  581. hidden_states,
  582. attention_mask,
  583. epoch_id=-1,
  584. head_masks=None):
  585. all_encoder_layers = [hidden_states]
  586. for i in range(self.config.num_hidden_layers):
  587. if self.config.transformer_type.lower() == 'universal':
  588. hidden_states = self.layer[i % self.config.num_rolled_layers](
  589. hidden_states, attention_mask)
  590. elif self.config.transformer_type.lower() == 'albert':
  591. a = i // (
  592. self.config.num_hidden_layers
  593. // self.config.num_rolled_layers)
  594. hidden_states = self.layer[a](hidden_states, attention_mask)
  595. all_encoder_layers.append(hidden_states)
  596. return all_encoder_layers
  597. class BERTEncoderACT(nn.Module):
  598. def __init__(self, config):
  599. super(BERTEncoderACT, self).__init__()
  600. self.layer = BERTLayer(config)
  601. p = nn.Linear(config.hidden_size, 1)
  602. self.p = nn.ModuleList(
  603. [copy.deepcopy(p) for _ in range(config.num_hidden_layers)])
  604. # Following act paper, set bias init ones
  605. for module in self.p:
  606. module.bias.data.fill_(1.)
  607. self.config = config
  608. self.act_max_steps = config.num_hidden_layers
  609. self.threshold = 0.99
  610. def should_continue(self, halting_probability, n_updates):
  611. return (halting_probability.lt(self.threshold).__and__(
  612. n_updates.lt(self.act_max_steps))).any()
  613. def forward(self, hidden_states, attention_mask):
  614. all_encoder_layers = [hidden_states]
  615. batch_size, seq_len, hdim = hidden_states.size()
  616. halting_probability = torch.zeros(batch_size, seq_len).cuda()
  617. remainders = torch.zeros(batch_size, seq_len).cuda()
  618. n_updates = torch.zeros(batch_size, seq_len).cuda()
  619. for i in range(self.act_max_steps):
  620. p = torch.sigmoid(self.p[i](hidden_states).squeeze(2))
  621. still_running = halting_probability.lt(1.0).float()
  622. new_halted = (halting_probability + p * still_running).gt(
  623. self.threshold).float() * still_running
  624. still_running = (halting_probability + p * still_running).le(
  625. self.threshold).float() * still_running
  626. halting_probability = halting_probability + p * still_running
  627. remainders = remainders + new_halted * (1 - halting_probability)
  628. halting_probability = halting_probability + new_halted * remainders
  629. n_updates = n_updates + still_running + new_halted
  630. update_weights = (p * still_running
  631. + new_halted * remainders).unsqueeze(2)
  632. transformed_states = self.layer(hidden_states, attention_mask)
  633. hidden_states = transformed_states * update_weights + hidden_states * (
  634. 1 - update_weights)
  635. all_encoder_layers.append(hidden_states)
  636. if not self.should_continue(halting_probability, n_updates):
  637. break
  638. return all_encoder_layers, torch.mean(n_updates + remainders)
  639. class BERTPooler(nn.Module):
  640. def __init__(self, config):
  641. super(BERTPooler, self).__init__()
  642. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  643. self.activation = nn.Tanh()
  644. def forward(self, hidden_states):
  645. # We "pool" the model by simply taking the hidden state corresponding
  646. # to the first token.
  647. first_token_tensor = hidden_states[:, 0]
  648. pooled_output = self.dense(first_token_tensor)
  649. pooled_output = self.activation(pooled_output)
  650. return pooled_output
  651. class BertModel(nn.Module):
  652. """BERT model ("Bidirectional Embedding Representations from a Transformer").
  653. Example:
  654. >>> # Already been converted into WordPiece token ids
  655. >>> input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  656. >>> input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  657. >>> token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
  658. >>> config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
  659. >>> num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
  660. >>> model = modeling.BertModel(config=config)
  661. >>> all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
  662. """
  663. def __init__(self, config: BertConfig):
  664. """Constructor for BertModel.
  665. Args:
  666. config: `BertConfig` instance.
  667. """
  668. super(BertModel, self).__init__()
  669. self.config = config
  670. self.embeddings = BERTEmbeddings(config)
  671. if config.transformer_type.lower() == 'original':
  672. self.encoder = BERTEncoder(config)
  673. elif config.transformer_type.lower() == 'universal':
  674. self.encoder = BERTEncoderRolled(config)
  675. elif config.transformer_type.lower() == 'albert':
  676. self.encoder = BERTEncoderRolled(config)
  677. elif config.transformer_type.lower() == 'act':
  678. self.encoder = BERTEncoderACT(config)
  679. elif config.transformer_type.lower() == 'textnas':
  680. from textnas_final import input_dict, op_dict, skip_dict
  681. self.encoder = TextNASEncoder(config, op_dict, input_dict,
  682. skip_dict)
  683. else:
  684. raise ValueError('Not support transformer type: {}'.format(
  685. config.transformer_type.lower()))
  686. self.pooler = BERTPooler(config)
  687. def forward(self,
  688. input_ids,
  689. token_type_ids=None,
  690. attention_mask=None,
  691. epoch_id=-1,
  692. head_masks=None,
  693. adv_embedding=None):
  694. if attention_mask is None:
  695. attention_mask = torch.ones_like(input_ids)
  696. if token_type_ids is None:
  697. token_type_ids = torch.zeros_like(input_ids)
  698. # We create a 3D attention mask from a 2D tensor mask.
  699. # Sizes are [batch_size, 1, 1, to_seq_length]
  700. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  701. # this attention mask is more simple than the triangular masking of causal attention
  702. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  703. extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  704. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  705. # masked positions, this operation will create a tensor which is 0.0 for
  706. # positions we want to attend and -10000.0 for masked positions.
  707. # Since we are adding it to the raw scores before the softmax, this is
  708. # effectively the same as removing these entirely.
  709. extended_attention_mask = extended_attention_mask.to(
  710. dtype=next(self.parameters()).dtype) # fp16 compatibility
  711. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  712. embedding_output, word_embeddings = self.embeddings(
  713. input_ids, token_type_ids, adv_embedding)
  714. if self.config.transformer_type.lower() == 'act':
  715. all_encoder_layers, act_loss = self.encoder(
  716. embedding_output, extended_attention_mask)
  717. elif self.config.transformer_type.lower() == 'reformer':
  718. sequence_output = self.encoder(embedding_output)
  719. all_encoder_layers = [sequence_output, sequence_output]
  720. else:
  721. all_encoder_layers = self.encoder(embedding_output,
  722. extended_attention_mask,
  723. epoch_id, head_masks)
  724. all_encoder_layers.insert(0, word_embeddings)
  725. sequence_output = all_encoder_layers[-1]
  726. if not self.config.safer_fp16:
  727. pooled_output = self.pooler(sequence_output)
  728. else:
  729. pooled_output = sequence_output[:, 0]
  730. return all_encoder_layers, pooled_output
  731. class BertForSequenceClassificationMultiTask(nn.Module):
  732. """BERT model for classification.
  733. This module is composed of the BERT model with a linear layer on top of
  734. the pooled output.
  735. Example:
  736. >>> # Already been converted into WordPiece token ids
  737. >>> input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  738. >>> input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  739. >>> token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
  740. >>> config = BertConfig(vocab_size=32000, hidden_size=512,
  741. >>> num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
  742. >>> num_labels = 2
  743. >>> model = BertForSequenceClassification(config, num_labels)
  744. >>> logits = model(input_ids, token_type_ids, input_mask)
  745. """
  746. def __init__(self, config, label_list, core_encoder):
  747. super(BertForSequenceClassificationMultiTask, self).__init__()
  748. if core_encoder.lower() == 'bert':
  749. self.bert = BertModel(config)
  750. elif core_encoder.lower() == 'lstm':
  751. self.bert = LSTMModel(config)
  752. else:
  753. raise ValueError(
  754. 'Only support lstm or bert, but got {}'.format(core_encoder))
  755. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  756. self.classifier = nn.ModuleList()
  757. for label in label_list:
  758. self.classifier.append(nn.Linear(config.hidden_size, len(label)))
  759. self.label_list = label_list
  760. def init_weights(module):
  761. if isinstance(module, (nn.Linear, nn.Embedding)):
  762. # Slightly different from the TF version which uses truncated_normal for initialization
  763. # cf https://github.com/pytorch/pytorch/pull/5617
  764. module.weight.data.normal_(
  765. mean=0.0, std=config.initializer_range)
  766. elif isinstance(module, BERTLayerNorm):
  767. module.beta.data.normal_(
  768. mean=0.0, std=config.initializer_range)
  769. module.gamma.data.normal_(
  770. mean=0.0, std=config.initializer_range)
  771. if isinstance(module, nn.Linear):
  772. module.bias.data.zero_()
  773. self.apply(init_weights)
  774. def forward(self,
  775. input_ids,
  776. token_type_ids,
  777. attention_mask,
  778. labels=None,
  779. labels_index=None,
  780. epoch_id=-1,
  781. head_masks=None,
  782. adv_embedding=None,
  783. return_embedding=False,
  784. loss_weight=None):
  785. all_encoder_layers, pooled_output = self.bert(input_ids,
  786. token_type_ids,
  787. attention_mask, epoch_id,
  788. head_masks,
  789. adv_embedding)
  790. pooled_output = self.dropout(pooled_output)
  791. logits = [classifier(pooled_output) for classifier in self.classifier]
  792. if labels is not None:
  793. loss_fct = CrossEntropyLoss(reduction='none')
  794. regression_loss_fct = nn.MSELoss(reduction='none')
  795. labels_lst = torch.unbind(labels, 1)
  796. loss_lst = []
  797. for index, (label, logit) in enumerate(zip(labels_lst, logits)):
  798. if len(self.label_list[index]) != 1:
  799. loss = loss_fct(logit, label.long())
  800. else:
  801. loss = regression_loss_fct(logit.squeeze(-1), label)
  802. labels_mask = (labels_index == index).to(
  803. dtype=next(self.parameters()).dtype)
  804. if loss_weight is not None:
  805. loss = loss * loss_weight[index]
  806. loss = torch.mean(loss * labels_mask)
  807. loss_lst.append(loss)
  808. if not return_embedding:
  809. return sum(loss_lst), logits
  810. else:
  811. return sum(loss_lst), logits, all_encoder_layers[0]
  812. else:
  813. return logits