backbone.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149
  1. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import copy
  17. import math
  18. import os
  19. from dataclasses import dataclass
  20. from typing import Any, Dict, List, Optional, Union
  21. import numpy as np
  22. import torch
  23. import torch.nn.functional as F
  24. from torch import Tensor, nn
  25. from torch.nn.init import xavier_uniform_
  26. from transformers import (BertConfig, BertModel, BertTokenizer, RobertaConfig,
  27. RobertaModel, RobertaTokenizer)
  28. from transformers.activations import ACT2FN
  29. from transformers.modeling_utils import PreTrainedModel
  30. from modelscope.utils import logger as logging
  31. from .configuration import PlugConfig
# File names that PlugPreTrainedModel.from_pretrained joins onto the
# pretrained-model directory to locate the config and the weights.
CONFIG_NAME = 'config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
  34. class MultiHeadedAttention(nn.Module): # SelfAttention
  35. """
  36. Multi-Head Attention module from
  37. "Attention is All You Need"
  38. :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.
  39. Similar to standard `dot` attention but uses
  40. multiple attention distributions simultaneously
  41. to select relevant items.
  42. .. mermaid::
  43. graph BT
  44. A[key]
  45. B[value]
  46. C[query]
  47. O[output]
  48. subgraph Attn
  49. D[Attn 1]
  50. E[Attn 2]
  51. F[Attn N]
  52. end
  53. A --> D
  54. C --> D
  55. A --> E
  56. C --> E
  57. A --> F
  58. C --> F
  59. D --> O
  60. E --> O
  61. F --> O
  62. B --> O
  63. Also includes several additional tricks.
  64. Args:
  65. head_count (int): number of parallel heads
  66. model_dim (int): the dimension of keys/values/queries,
  67. must be divisible by head_count
  68. dropout (float): dropout parameter
  69. """
  70. def __init__(self,
  71. head_count,
  72. model_dim,
  73. dropout=0.1,
  74. use_final_linear=True):
  75. assert model_dim % head_count == 0
  76. self.dim_per_head = model_dim // head_count
  77. self.model_dim = model_dim
  78. super().__init__()
  79. self.head_count = head_count
  80. self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head)
  81. self.linear_values = nn.Linear(model_dim,
  82. head_count * self.dim_per_head)
  83. self.linear_query = nn.Linear(model_dim,
  84. head_count * self.dim_per_head)
  85. self.softmax = nn.Softmax(dim=-1)
  86. self.dropout = nn.Dropout(dropout)
  87. self.use_final_linear = use_final_linear
  88. if (self.use_final_linear):
  89. self.final_linear = nn.Linear(model_dim, model_dim)
  90. def forward(self,
  91. key,
  92. value,
  93. query,
  94. mask=None,
  95. layer_cache=None,
  96. type=None,
  97. predefined_graph_1=None,
  98. return_attn=False):
  99. """
  100. Compute the context vector and the attention vectors.
  101. Args:
  102. key (`FloatTensor`): set of `key_len`
  103. key vectors `[batch, key_len, dim]`
  104. value (`FloatTensor`): set of `key_len`
  105. value vectors `[batch, key_len, dim]`
  106. query (`FloatTensor`): set of `query_len`
  107. query vectors `[batch, query_len, dim]`
  108. mask: binary mask indicating which keys have
  109. non-zero attention `[batch, query_len, key_len]`
  110. Returns:
  111. (`FloatTensor`, `FloatTensor`) :
  112. * output context vectors `[batch, query_len, dim]`
  113. * one of the attention vectors `[batch, query_len, key_len]`
  114. """
  115. batch_size = key.size(0)
  116. dim_per_head = self.dim_per_head
  117. head_count = self.head_count
  118. def shape(x):
  119. """ projection """
  120. return x.view(batch_size, -1, head_count, dim_per_head) \
  121. .transpose(1, 2)
  122. def unshape(x):
  123. """ compute context """
  124. return x.transpose(1, 2).contiguous() \
  125. .view(batch_size, -1, head_count * dim_per_head)
  126. # 1) Project key, value, and query.
  127. if layer_cache is not None:
  128. if type == 'self':
  129. query, key, value = self.linear_query(query), self.linear_keys(
  130. query), self.linear_values(query)
  131. key = shape(key)
  132. value = shape(value)
  133. device = key.device
  134. if layer_cache['self_keys'] is not None:
  135. key = torch.cat((layer_cache['self_keys'].to(device), key),
  136. dim=2)
  137. if layer_cache['self_values'] is not None:
  138. value = torch.cat(
  139. (layer_cache['self_values'].to(device), value), dim=2)
  140. layer_cache['self_keys'] = key
  141. layer_cache['self_values'] = value
  142. elif type == 'context':
  143. query = self.linear_query(query)
  144. if layer_cache['memory_keys'] is None:
  145. key, value = self.linear_keys(key), self.linear_values(
  146. value)
  147. key = shape(key)
  148. value = shape(value)
  149. else:
  150. key, value = layer_cache['memory_keys'], layer_cache[
  151. 'memory_values']
  152. layer_cache['memory_keys'] = key
  153. layer_cache['memory_values'] = value
  154. else:
  155. key = self.linear_keys(key)
  156. value = self.linear_values(value)
  157. query = self.linear_query(query)
  158. key = shape(key)
  159. value = shape(value)
  160. query = shape(query)
  161. # 2) Calculate and scale scores.
  162. query = query / math.sqrt(dim_per_head)
  163. scores = torch.matmul(query, key.transpose(2, 3))
  164. if mask is not None:
  165. mask = mask.unsqueeze(1).expand_as(scores)
  166. scores = scores.masked_fill(mask, float('-inf'))
  167. # 3) Apply attention dropout and compute context vectors.
  168. attn = self.softmax(scores)
  169. if (predefined_graph_1 is not None):
  170. attn_masked = attn[:, -1] * predefined_graph_1
  171. attn_masked = attn_masked / (
  172. torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9)
  173. attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1)
  174. drop_attn = self.dropout(attn)
  175. if (self.use_final_linear):
  176. context = unshape(torch.matmul(drop_attn, value))
  177. output = self.final_linear(context)
  178. if return_attn:
  179. return output, attn
  180. else:
  181. return output
  182. else:
  183. context = torch.matmul(drop_attn, value)
  184. if return_attn:
  185. return context, attn
  186. else:
  187. return context
  188. class PositionwiseFeedForward(nn.Module): # Output
  189. """ A two-layer Feed-Forward-Network with residual layer norm.
  190. Args:
  191. d_model (int): the size of input for the first-layer of the FFN.
  192. d_ff (int): the hidden layer size of the second-layer
  193. of the FNN.
  194. dropout (float): dropout probability in :math:`[0, 1)`.
  195. """
  196. def __init__(self, d_model, d_ff, dropout=0.1):
  197. super().__init__()
  198. self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
  199. self.w_1 = nn.Linear(d_model, d_ff)
  200. self.actv = ACT2FN['gelu_new']
  201. self.dropout_1 = nn.Dropout(dropout)
  202. self.w_2 = nn.Linear(d_ff, d_model)
  203. self.dropout_2 = nn.Dropout(dropout)
  204. def forward(self, x):
  205. inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x))))
  206. output = self.dropout_2(self.w_2(inter))
  207. return output + x
  208. class TransformerDecoderLayer(nn.Module): # Layer
  209. """
  210. Args:
  211. d_model (int): the dimension of keys/values/queries in
  212. MultiHeadedAttention, also the input size of
  213. the first-layer of the PositionwiseFeedForward.
  214. heads (int): the number of heads for MultiHeadedAttention.
  215. d_ff (int): the second-layer of the PositionwiseFeedForward.
  216. dropout (float): dropout probability(0-1.0).
  217. self_attn_type (string): type of self-attention scaled-dot, average
  218. """
  219. MAX_SIZE = 5000
  220. def __init__(self, d_model, heads, d_ff, dropout):
  221. super().__init__()
  222. self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
  223. self.context_attn = MultiHeadedAttention(
  224. heads, d_model, dropout=dropout)
  225. self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
  226. self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
  227. self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
  228. self.drop = nn.Dropout(dropout)
  229. mask = self._get_attn_subsequent_mask(self.MAX_SIZE)
  230. # Register self.mask as a buffer in TransformerDecoderLayer, so
  231. # it gets TransformerDecoderLayer's cuda behavior automatically.
  232. self.register_buffer('mask', mask)
  233. def forward(self,
  234. inputs,
  235. memory_bank,
  236. src_pad_mask,
  237. tgt_pad_mask,
  238. previous_input=None,
  239. layer_cache=None,
  240. step=None):
  241. """
  242. Args:
  243. inputs (`FloatTensor`): `[batch_size x 1 x model_dim]`
  244. memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]`
  245. src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]`
  246. tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]`
  247. Returns:
  248. (`FloatTensor`, `FloatTensor`, `FloatTensor`):
  249. * output `[batch_size x 1 x model_dim]`
  250. * attn `[batch_size x 1 x src_len]`
  251. * all_input `[batch_size x current_step x model_dim]`
  252. """
  253. dec_mask = torch.gt(
  254. tgt_pad_mask.type(torch.uint8)
  255. + self.mask[:, :tgt_pad_mask.size(1), :tgt_pad_mask.size(1)].type(
  256. torch.uint8), 0)
  257. input_norm = self.layer_norm_1(inputs)
  258. all_input = input_norm
  259. if previous_input is not None:
  260. all_input = torch.cat((previous_input, input_norm), dim=1)
  261. dec_mask = None
  262. query = self.self_attn(
  263. all_input,
  264. all_input,
  265. input_norm,
  266. mask=dec_mask,
  267. layer_cache=layer_cache,
  268. type='self')
  269. query = self.drop(query) + inputs
  270. query_norm = self.layer_norm_2(query)
  271. mid, attn = self.context_attn(
  272. memory_bank,
  273. memory_bank,
  274. query_norm,
  275. mask=src_pad_mask,
  276. layer_cache=layer_cache,
  277. type='context',
  278. return_attn=True)
  279. output = self.feed_forward(self.drop(mid) + query)
  280. return output, attn, all_input
  281. def _get_attn_subsequent_mask(self, size):
  282. """
  283. Get an attention mask to avoid using the subsequent info.
  284. Args:
  285. size: int
  286. Returns:
  287. (`LongTensor`):
  288. * subsequent_mask `[1 x size x size]`
  289. """
  290. attn_shape = (1, size, size)
  291. subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
  292. subsequent_mask = torch.from_numpy(subsequent_mask)
  293. return subsequent_mask
  294. class PositionalEncoding(nn.Module):
  295. def __init__(self, dropout, dim, max_len=5000):
  296. super().__init__()
  297. pe = torch.zeros(max_len, dim)
  298. position = torch.arange(0, max_len).unsqueeze(1)
  299. div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float)
  300. * -(math.log(10000.0) / dim)))
  301. pe[:, 0::2] = torch.sin(position.float() * div_term)
  302. pe[:, 1::2] = torch.cos(position.float() * div_term)
  303. pe = pe.unsqueeze(0)
  304. self.register_buffer('pe', pe)
  305. self.dropout = nn.Dropout(dropout)
  306. self.dim = dim
  307. def forward(self, emb, step=None):
  308. emb = emb * math.sqrt(self.dim)
  309. if (step):
  310. emb = emb + self.pe[:, step][:, None, :]
  311. else:
  312. emb = emb + self.pe[:, :emb.size(1)]
  313. emb = self.dropout(emb)
  314. return emb
  315. def get_emb(self, emb):
  316. return self.pe[:, :emb.size(1)]
  317. class TransformerDecoderState:
  318. def __init__(self, src: Tensor, cache_num_layers: int = -1):
  319. self.src: Tensor = src
  320. self.previous_input: Tensor = None
  321. self.previous_layer_inputs: Tensor = None
  322. self.cache: Optional[Dict[str, Any]] = None
  323. if cache_num_layers != -1:
  324. self._init_cache(cache_num_layers)
  325. def update_state(self, new_input, previous_layer_inputs):
  326. self.previous_input = new_input
  327. self.previous_layer_inputs = previous_layer_inputs
  328. self.cache = None
  329. def _init_cache(self, num_layers):
  330. self.cache = {}
  331. for layer in range(num_layers):
  332. layer_cache = {'memory_keys': None, 'memory_values': None}
  333. layer_cache['self_keys'] = None
  334. layer_cache['self_values'] = None
  335. self.cache['layer_{}'.format(layer)] = layer_cache
  336. def map_batch_fn(self, fn):
  337. def _recursive_map(struct, batch_dim=0):
  338. for k, v in struct.items():
  339. if v is not None:
  340. if isinstance(v, dict):
  341. _recursive_map(v)
  342. else:
  343. struct[k] = fn(v, batch_dim)
  344. self.src = fn(self.src, 0)
  345. if self.cache is not None:
  346. _recursive_map(self.cache)
  347. class TransformerDecoder(nn.Module): # Decoder
  348. """
  349. The Transformer decoder from "Attention is All You Need".
  350. .. mermaid::
  351. graph BT
  352. A[input]
  353. B[multi-head self-attn]
  354. BB[multi-head src-attn]
  355. C[feed forward]
  356. O[output]
  357. A --> B
  358. B --> BB
  359. BB --> C
  360. C --> O
  361. Args:
  362. num_layers (int): number of encoder layers.
  363. d_model (int): size of the model
  364. heads (int): number of heads
  365. d_ff (int): size of the inner FF layer
  366. dropout (float): dropout parameters
  367. embeddings (:obj:`onmt.modules.Embeddings`):
  368. embeddings to use, should have positional encodings
  369. attn_type (str): if using a separate copy attention
  370. """
  371. decoder_type = 'transformer'
  372. def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings):
  373. super().__init__()
  374. # Basic attributes.
  375. self.num_layers = num_layers
  376. self.embeddings = embeddings
  377. self.pos_emb = PositionalEncoding(dropout,
  378. self.embeddings.embedding_dim)
  379. # Build TransformerDecoder.
  380. self.transformer_layers = nn.ModuleList([
  381. TransformerDecoderLayer(d_model, heads, d_ff, dropout)
  382. for _ in range(num_layers)
  383. ])
  384. self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
  385. self.state = None
  386. def forward(self,
  387. state: TransformerDecoderState,
  388. tgt: Tensor,
  389. memory_bank: Tensor,
  390. step: int = None,
  391. memory_masks: Tensor = None):
  392. src_words = state.src
  393. tgt_words = tgt
  394. src_batch, src_len = src_words.size()
  395. tgt_batch, tgt_len = tgt_words.size()
  396. # Run the forward pass of the TransformerDecoder.
  397. # emb = self.embeddings(tgt, step=step)
  398. emb = self.embeddings(tgt)
  399. assert emb.dim() == 3 # len x batch x embedding_dim
  400. output = self.pos_emb(emb, step)
  401. src_memory_bank = memory_bank
  402. padding_idx = self.embeddings.padding_idx
  403. tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1) \
  404. .expand(tgt_batch, tgt_len, tgt_len)
  405. if (memory_masks is not None):
  406. src_len = memory_masks.size(-1)
  407. src_pad_mask = memory_masks.expand(src_batch, tgt_len, src_len)
  408. else:
  409. src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \
  410. .expand(src_batch, tgt_len, src_len)
  411. if state.cache is None:
  412. saved_inputs = []
  413. attns = []
  414. for i in range(self.num_layers):
  415. prev_layer_input = None
  416. if state.cache is None:
  417. if state.previous_input is not None:
  418. prev_layer_input = state.previous_layer_inputs[i]
  419. output, attn, all_input \
  420. = self.transformer_layers[i](
  421. output, src_memory_bank,
  422. src_pad_mask, tgt_pad_mask,
  423. previous_input=prev_layer_input,
  424. layer_cache=state.cache['layer_{}'.format(i)]
  425. if state.cache is not None else None,
  426. step=step)
  427. if state.cache is None:
  428. saved_inputs.append(all_input)
  429. attns.append(attn)
  430. if state.cache is None:
  431. saved_inputs = torch.stack(saved_inputs)
  432. output = self.layer_norm(output)
  433. # Process the result and update the attentions.
  434. if state.cache is None:
  435. state.update_state(tgt, saved_inputs)
  436. return output, attns, state
  437. class PlugPointerGenerator(nn.Module):
  438. def __init__(self, hidden_size, vocab_size):
  439. super().__init__()
  440. self.dense = nn.Linear(hidden_size, vocab_size)
  441. self.gen_func = nn.LogSoftmax(-1)
  442. def forward(self, x):
  443. x = self.dense(x)
  444. x = self.gen_func(x)
  445. return x
  446. class PlugPreTrainedModel(PreTrainedModel):
  447. """
  448. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  449. models.
  450. """
  451. config_class = PlugConfig
  452. base_model_prefix = 'plug'
  453. @classmethod
  454. def from_pretrained(
  455. cls, pretrained_model_name_or_path: Optional[Union[str,
  456. os.PathLike]]):
  457. config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
  458. config = PlugConfig.from_json_file(config_file) if os.path.isfile(
  459. config_file) else PlugConfig()
  460. config.encoder_pth = os.path.join(pretrained_model_name_or_path,
  461. config.encoder_pth)
  462. checkpoint_file = os.path.join(pretrained_model_name_or_path,
  463. WEIGHTS_NAME)
  464. checkpoint = torch.load(checkpoint_file) if os.path.isfile(
  465. checkpoint_file) else None
  466. return cls(config, checkpoint)
  467. class PlugModel(PlugPreTrainedModel): # Model
  468. def __init__(self, config, checkpoint=None):
  469. super().__init__(config)
  470. self.config = config
  471. if config.encoder == 'bert' or config.encoder == 'zh_bert':
  472. self.bert = BertModel(
  473. BertConfig.from_pretrained(config.encoder_pth))
  474. elif config.encoder == 'roberta':
  475. self.bert = RobertaModel(
  476. RobertaConfig.from_pretrained(config.encoder_pth))
  477. if (config.max_pos > 512):
  478. my_pos_embeddings = nn.Embedding(
  479. config.max_pos, self.bert.model.config.hidden_size)
  480. my_pos_embeddings.weight.data[:
  481. 512] = self.bert.embeddings.position_embeddings.weight.data
  482. my_pos_embeddings.weight.data[
  483. 512:] = self.bert.embeddings.position_embeddings.weight.data[
  484. -1][None, :].repeat(config.max_pos - 512, 1)
  485. self.bert.model.embeddings.position_embeddings = my_pos_embeddings
  486. self.vocab_size = self.bert.config.vocab_size
  487. tgt_embeddings = nn.Embedding(
  488. self.vocab_size,
  489. self.bert.config.hidden_size,
  490. padding_idx=1 if config.encoder == 'roberta' else 0)
  491. if config.share_emb:
  492. tgt_embeddings.weight = copy.deepcopy(
  493. self.bert.model.embeddings.word_embeddings.weight)
  494. self.decoder = TransformerDecoder(
  495. config.dec_layers,
  496. config.dec_hidden_size,
  497. heads=config.dec_heads,
  498. d_ff=config.dec_ff_size,
  499. dropout=config.dec_dropout,
  500. embeddings=tgt_embeddings)
  501. self.generator = PlugPointerGenerator(config.dec_hidden_size,
  502. self.vocab_size)
  503. self.generator.dense.weight = self.decoder.embeddings.weight
  504. if checkpoint is not None:
  505. for key in list(checkpoint['model'].keys()):
  506. if key.startswith('module.'):
  507. checkpoint['model'][key.replace(
  508. 'module.', '')] = checkpoint['model'][key]
  509. checkpoint['model'].pop(key)
  510. if key.startswith('plug.'):
  511. checkpoint['model'][key.replace(
  512. 'plug.', '')] = checkpoint['model'][key]
  513. checkpoint['model'].pop(key)
  514. msg = self.load_state_dict(checkpoint['model'], strict=False)
  515. print(msg)
  516. else:
  517. for module in self.decoder.modules():
  518. if isinstance(module, (nn.Linear, nn.Embedding)):
  519. module.weight.data.normal_(mean=0.0, std=0.02)
  520. elif isinstance(module, nn.LayerNorm):
  521. module.bias.data.zero_()
  522. module.weight.data.fill_(1.0)
  523. if isinstance(module, nn.Linear) and module.bias is not None:
  524. module.bias.data.zero_()
  525. for p in self.generator.parameters():
  526. if p.dim() > 1:
  527. xavier_uniform_(p)
  528. else:
  529. p.data.zero_()
  530. if config.use_bert_emb:
  531. if config.encoder == 'roberta':
  532. tgt_embeddings = nn.Embedding(
  533. self.vocab_size,
  534. self.bert.config.hidden_size,
  535. padding_idx=1)
  536. else:
  537. tgt_embeddings = nn.Embedding(
  538. self.vocab_size,
  539. self.bert.config.hidden_size,
  540. padding_idx=0)
  541. tgt_embeddings.weight = copy.deepcopy(
  542. self.bert.embeddings.word_embeddings.weight)
  543. self.decoder.embeddings = tgt_embeddings
  544. self.generator.dense.weight = self.decoder.embeddings.weight
  545. def forward(self, src, tgt, mask_src, token_type_ids):
  546. top_vec, _ = self.bert(
  547. src, mask_src, token_type_ids=token_type_ids, return_dict=False)
  548. state = TransformerDecoderState(src)
  549. decoder_outputs, attns, _ = self.decoder(state, tgt[:, :-1], top_vec)
  550. return decoder_outputs, attns[-1], top_vec
  551. class LabelSmoothingLoss(nn.Module):
  552. """
  553. With label smoothing,
  554. KL-divergence between q_{smoothed ground truth prob.}(w)
  555. and p_{prob. computed by model}(w) is minimized.
  556. """
  557. def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
  558. assert 0.0 < label_smoothing <= 1.0
  559. self.padding_idx = ignore_index
  560. super(LabelSmoothingLoss, self).__init__()
  561. smoothing_value = label_smoothing / (tgt_vocab_size - 2)
  562. one_hot = torch.full((tgt_vocab_size, ), smoothing_value)
  563. one_hot[self.padding_idx] = 0
  564. self.register_buffer('one_hot', one_hot.unsqueeze(0))
  565. self.confidence = 1.0 - label_smoothing
  566. def forward(self, output, target):
  567. """
  568. output (FloatTensor): batch_size x n_classes
  569. target (LongTensor): batch_size
  570. """
  571. model_prob = self.one_hot.repeat(target.size(0), 1)
  572. model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
  573. model_prob.masked_fill_((target == self.padding_idx).unsqueeze(1), 0)
  574. return F.kl_div(output, model_prob, reduction='sum')
  575. class NMTLossCompute(nn.Module):
  576. """
  577. Standard NMT Loss Computation.
  578. """
  579. def __init__(self, generator, symbols, vocab_size, label_smoothing=0.0):
  580. super().__init__()
  581. self.generator = generator
  582. self.padding_idx = symbols['PAD']
  583. if label_smoothing > 0:
  584. self.criterion = LabelSmoothingLoss(
  585. label_smoothing, vocab_size, ignore_index=self.padding_idx)
  586. else:
  587. self.criterion = nn.NLLLoss(
  588. ignore_index=self.padding_idx, reduction='sum')
  589. def _bottle(self, _v):
  590. return _v.view(-1, _v.size(2))
  591. def _unbottle(self, _v, batch_size):
  592. return _v.view(-1, batch_size, _v.size(1))
  593. def forward(self, tgt, output):
  594. target = tgt[:, 1:]
  595. batch_size, decoder_length = target.size(0), target.size(1)
  596. normalization = target.ne(self.padding_idx).sum()
  597. bottled_output = self._bottle(output)
  598. scores = self.generator(bottled_output)
  599. gtruth = target.contiguous().view(-1)
  600. loss = self.criterion(scores, gtruth)
  601. loss = loss.div(float(normalization))
  602. return loss, scores.view(batch_size, decoder_length, -1)
  603. class PlugForConditionalGeneration(PlugPreTrainedModel):
    @dataclass
    class Batch:
        """Inputs for one call to ``translate_batch``."""
        # The four fields below are read by _fast_translate_batch.
        batch_size: int
        src: torch.Tensor
        tgt: torch.Tensor
        mask_src: torch.Tensor
        token_type_ids: torch.Tensor
        # Optional metadata; presumably carried through for the caller -
        # TODO confirm against the code that builds and consumes batches.
        query_id: List[None] = None
        src_str: List[List[str]] = None
        tgt_str: List[str] = None
  614. def __init__(self, config, checkpoint=None, dataset: str = 'default'):
  615. super().__init__(config)
  616. self.logger = logging.get_logger()
  617. self.config = config
  618. if config.encoder == 'roberta':
  619. tokenizer = RobertaTokenizer.from_pretrained(
  620. config.encoder_pth, do_lower_case=False)
  621. symbols = {
  622. 'BOS': tokenizer.cls_token_id,
  623. 'EOS': tokenizer.sep_token_id,
  624. 'PAD': tokenizer.pad_token_id,
  625. 'EOQ': tokenizer.unk_token_id
  626. }
  627. elif config.encoder == 'bert' or config.encoder == 'zh_bert':
  628. tokenizer = BertTokenizer.from_pretrained(
  629. config.encoder_pth, do_lower_case=True)
  630. symbols = {
  631. 'BOS': tokenizer.vocab['[CLS]'],
  632. 'EOS': tokenizer.vocab['[SEP]'],
  633. 'PAD': tokenizer.vocab['[PAD]'],
  634. 'EOQ': tokenizer.vocab['[unused2]']
  635. }
  636. self.tokenizer = tokenizer
  637. self.symbols = symbols
  638. self.plug = PlugModel(config, checkpoint)
  639. self.loss = NMTLossCompute(self.plug.generator, symbols,
  640. self.plug.vocab_size,
  641. config.label_smoothing)
  642. # for generation
  643. self.config.dataset = dataset
  644. self.start_token = self.symbols['BOS']
  645. self.end_token = self.symbols['EOS']
  646. def forward(self, src, tgt, mask_src=None, token_type_ids=None):
  647. if mask_src is None:
  648. mask_src = src.ne(self.symbols['PAD']).long()
  649. output = self.plug(src, tgt, mask_src, token_type_ids)[0]
  650. loss = self.loss(tgt, output)
  651. return loss
  652. def translate_batch(self,
  653. batch: 'Batch',
  654. fast: bool = False,
  655. *args,
  656. **kwargs):
  657. """
  658. Translate a batch of sentences.
  659. Mostly a wrapper around :obj:`Beam`.
  660. Args:
  661. batch (:obj:`Batch`): a batch from a dataset object
  662. data (:obj:`Dataset`): the dataset object
  663. fast (bool): enables fast beam search (may not support all features)
  664. Todo:
  665. Shouldn't need the original dataset.
  666. """
  667. self.plug.eval()
  668. with torch.no_grad():
  669. return self._fast_translate_batch(batch, *args, **kwargs)
  670. def _tile(self, x, count, dim=0):
  671. perm = list(range(len(x.size())))
  672. if dim != 0:
  673. perm[0], perm[dim] = perm[dim], perm[0]
  674. x = x.permute(perm).contiguous()
  675. out_size = list(x.size())
  676. out_size[0] *= count
  677. batch = x.size(0)
  678. x = x.view(batch, -1) \
  679. .transpose(0, 1) \
  680. .repeat(count, 1) \
  681. .transpose(0, 1) \
  682. .contiguous() \
  683. .view(*out_size)
  684. if dim != 0:
  685. x = x.permute(perm).contiguous()
  686. return x
  687. def _top_k_top_p_filtering(self,
  688. logits,
  689. top_k=10,
  690. top_p=1.0,
  691. filter_value=-float('Inf'),
  692. min_tokens_to_keep=1):
  693. if top_k > 0:
  694. top_k = min(max(top_k, min_tokens_to_keep),
  695. logits.size(-1)) # Safety check
  696. # Remove all tokens with a probability less than the last token of the top-k
  697. indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1,
  698. None]
  699. logits[indices_to_remove] = filter_value
  700. if top_p < 1.0:
  701. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  702. cumulative_probs = torch.cumsum(
  703. F.softmax(sorted_logits, dim=-1), dim=-1)
  704. # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
  705. sorted_indices_to_remove = cumulative_probs > top_p
  706. if min_tokens_to_keep > 1:
  707. # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
  708. sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
  709. # Shift the indices to the right to keep also the first token above the threshold
  710. sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
  711. ..., :-1].clone()
  712. sorted_indices_to_remove[..., 0] = 0
  713. # scatter sorted tensors to original indexing
  714. indices_to_remove = sorted_indices_to_remove.scatter(
  715. 1, sorted_indices, sorted_indices_to_remove)
  716. logits[indices_to_remove] = filter_value
  717. return logits
  718. def _fast_translate_batch(self,
  719. batch: 'Batch',
  720. max_length: int = 80,
  721. min_length: int = 10,
  722. bad_words_ids=None,
  723. early_stopping=True,
  724. num_beams=3,
  725. length_penalty=1.2,
  726. repetition_penalty=1.2,
  727. no_repeat_ngram_size=4,
  728. do_sample=False,
  729. temperature=1.0,
  730. top_k=0,
  731. top_p=1.0,
  732. *args,
  733. **kwargs):
  734. # TODO: faster code path for beam_size == 1.
  735. # TODO: support these blacklisted features.
  736. num_beams = num_beams
  737. batch_size = batch.batch_size
  738. src = batch.src
  739. mask_src = batch.mask_src
  740. token_type_ids = batch.token_type_ids
  741. src_features, _ = self.plug.bert(
  742. src, mask_src, token_type_ids=token_type_ids, return_dict=False)
  743. state = TransformerDecoderState(src, self.plug.decoder.num_layers)
  744. device = src_features.device
  745. # Tile states and memory beam_size times.
  746. state.map_batch_fn(
  747. lambda state, dim: self._tile(state, num_beams, dim=dim))
  748. src_features = self._tile(src_features, num_beams, dim=0)
  749. batch_offset = torch.arange(
  750. batch_size, dtype=torch.long, device=device)
  751. beam_offset = torch.arange(
  752. 0,
  753. batch_size * num_beams,
  754. step=num_beams,
  755. dtype=torch.long,
  756. device=device)
  757. alive_seq = torch.full([batch_size * num_beams, 1],
  758. self.start_token,
  759. dtype=torch.long,
  760. device=device)
  761. # cal bad_words_ids pre dict
  762. bad_words_prefix_dict = {}
  763. bad_words_prefix_len = set([])
  764. if bad_words_ids is not None:
  765. for bw_id in bad_words_ids:
  766. key = tuple(bw_id[:-1])
  767. value = bw_id[-1]
  768. bad_words_prefix_dict[key] = bad_words_prefix_dict.get(
  769. key, []) + [value]
  770. bad_words_prefix_len.add(len(key))
  771. # Give full probability to the first beam on the first step.
  772. topk_log_probs = (
  773. torch.tensor(
  774. [0.0] + [float('-inf')] * (num_beams - 1),
  775. device=device).repeat(batch_size))
  776. # Structure that holds finished hypotheses.
  777. hypotheses = [[] for _ in range(batch_size)] # noqa: F812
  778. results = {}
  779. results['predictions'] = [[] for _ in range(batch_size)] # noqa: F812
  780. results['scores'] = [[] for _ in range(batch_size)] # noqa: F812
  781. results['gold_score'] = [0] * batch_size
  782. results['batch'] = batch
  783. for step in range(max_length):
  784. # self.logger.info(f'step: {step + 1} / {max_length}')
  785. decoder_input = alive_seq[:, -1].view(1, -1)
  786. # Decoder forward.
  787. decoder_input = decoder_input.transpose(0, 1)
  788. dec_out, attns, state = self.plug.decoder(
  789. state, decoder_input, src_features, step=step)
  790. # Generator forward.
  791. log_probs = self.plug.generator.forward(
  792. dec_out.transpose(0, 1).squeeze(0))
  793. vocab_size = log_probs.size(-1)
  794. if step < min_length:
  795. log_probs[:, self.end_token] = -1e20
  796. # filter bad word
  797. if len(bad_words_prefix_dict) > 0:
  798. # cal bad word banned token: batch_size * num_beams
  799. num_hypos = alive_seq.size(0)
  800. bad_word_banned_token = []
  801. for i in range(num_hypos):
  802. curr_banned_token = []
  803. for pre_len in bad_words_prefix_len:
  804. pre_key = tuple(alive_seq[i, step + 1 - pre_len:step
  805. + 1].cpu().numpy().tolist())
  806. curr_banned_token += bad_words_prefix_dict.get(
  807. pre_key, [])
  808. bad_word_banned_token.append(set(curr_banned_token))
  809. # set banned word prob=-1e20
  810. assert log_probs.size(0) == num_hypos
  811. for i in range(num_hypos):
  812. for banned_token in bad_word_banned_token[i]:
  813. log_probs[i, banned_token] = -1e20
  814. # do repetition_penalty
  815. if repetition_penalty > 1.0:
  816. """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
  817. # calculate prev_output_tokens for repetition_penalty: batch_size * num_beams
  818. prev_output_tokens = self.calc_banned_tokens(
  819. alive_seq, alive_seq.size(0), no_repeat_ngram_size,
  820. step + 1)
  821. # batch_size * num_beams
  822. for i in range(log_probs.size(0)):
  823. for previous_token in set(prev_output_tokens[i]):
  824. if log_probs[i, previous_token] < 0:
  825. log_probs[i, previous_token] *= repetition_penalty
  826. else:
  827. log_probs[i, previous_token] /= repetition_penalty
  828. # Multiply probs by the beam probability.
  829. curr_length_penalty = (step + 1)**length_penalty
  830. # '''
  831. if do_sample:
  832. _scores = log_probs / temperature
  833. _scores = self._top_k_top_p_filtering(
  834. _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=1
  835. ) # (batch_size * num_beams, vocab_size)
  836. # Sample 2 next words for each beam (so we have some spare tokens
  837. # and match output of greedy beam search)
  838. topk_ids = torch.multinomial(
  839. F.softmax(_scores, dim=-1),
  840. num_samples=1) # (batch_size * num_beams, 2)
  841. # Compute next scores
  842. _scores = F.log_softmax(
  843. _scores, dim=1) # (batch_size * num_beams, vocab_size)
  844. _scores += topk_log_probs.view(-1).unsqueeze(1)
  845. _scores = _scores / curr_length_penalty
  846. topk_scores = torch.gather(
  847. _scores, -1, topk_ids) # (batch_size * num_beams, 2)
  848. # log_probs += # (batch_size * num_beams, 2)
  849. # Match shape of greedy beam search
  850. topk_ids = topk_ids.view(
  851. -1, num_beams) # (batch_size, 2 * num_beams)
  852. topk_scores = topk_scores.view(
  853. -1, num_beams) # (batch_size, 2 * num_beams)
  854. # '''
  855. else:
  856. log_probs += topk_log_probs.view(-1).unsqueeze(1)
  857. curr_scores = log_probs / curr_length_penalty
  858. curr_scores = curr_scores.reshape(-1, num_beams * vocab_size)
  859. topk_scores, topk_ids = curr_scores.topk(num_beams, dim=-1)
  860. if (self.config.block_trigram):
  861. cur_len = alive_seq.size(1)
  862. if (cur_len > 3):
  863. for i in range(alive_seq.size(0)):
  864. fail = False
  865. words = [int(w) for w in alive_seq[i]]
  866. if self.config.encoder == 'roberta':
  867. # words = [self.vocab.convert_ids_to_tokens[w] for w in words]
  868. words = self.tokenizer.decode(
  869. words).strip().split()
  870. else:
  871. words = [
  872. self.tokenizer.ids_to_tokens[w] for w in words
  873. ]
  874. words = ' '.join(words).replace(' ##', '').split()
  875. if (len(words) <= 3):
  876. continue
  877. trigrams = [(words[i - 1], words[i], words[i + 1])
  878. for i in range(1,
  879. len(words) - 1)]
  880. trigram = tuple(trigrams[-1])
  881. if trigram in trigrams[:-1]:
  882. fail = True
  883. if fail:
  884. curr_scores[i] = -10e20
  885. # Recover log probs.
  886. topk_log_probs = topk_scores * curr_length_penalty
  887. # Resolve beam origin and true word ids.
  888. # topk_beam_index = topk_ids.div(vocab_size)
  889. topk_beam_index = topk_ids // vocab_size
  890. topk_ids = topk_ids.fmod(vocab_size)
  891. # Map beam_index to batch_index in the flat representation.
  892. batch_index = (
  893. topk_beam_index
  894. + beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
  895. select_indices = batch_index.view(-1)
  896. # Append last prediction.
  897. alive_seq = torch.cat([
  898. alive_seq.index_select(0, select_indices),
  899. topk_ids.view(-1, 1)
  900. ], -1)
  901. is_finished = topk_ids.eq(self.end_token)
  902. if step + 1 == max_length:
  903. is_finished.fill_(self.end_token)
  904. # End condition is top beam is finished.
  905. end_condition = is_finished[:, 0].eq(1)
  906. # Save finished hypotheses.
  907. if is_finished.any():
  908. predictions = alive_seq.view(-1, num_beams, alive_seq.size(-1))
  909. for i in range(is_finished.size(0)):
  910. b = batch_offset[i]
  911. if end_condition[i]:
  912. is_finished[i].fill_(self.end_token)
  913. finished_hyp = is_finished[i].nonzero().view(-1)
  914. # Store finished hypotheses for this batch.
  915. for j in finished_hyp:
  916. hypotheses[b].append(
  917. (topk_scores[i, j], predictions[i, j, 1:]))
  918. if early_stopping and len(hypotheses) == num_beams:
  919. end_condition[i] = True
  920. # If the batch reached the end, save the n_best hypotheses.
  921. if end_condition[i]:
  922. best_hyp = sorted(
  923. hypotheses[b], key=lambda x: x[0], reverse=True)
  924. if self.config.dataset == 'qg_ranking_test' or (
  925. self.config.dataset == 'paraphrase'
  926. and not self.config.sample_topk):
  927. for each in best_hyp[:num_beams]:
  928. score, pred = each
  929. results['scores'][b].append(score)
  930. results['predictions'][b].append(pred)
  931. else:
  932. score, pred = best_hyp[0]
  933. results['scores'][b].append(score)
  934. results['predictions'][b].append(pred)
  935. non_finished = end_condition.eq(0).nonzero().view(-1)
  936. # If all sentences are translated, no need to go further.
  937. if len(non_finished) == 0:
  938. break
  939. # Remove finished batches for the next step.
  940. topk_log_probs = topk_log_probs.index_select(0, non_finished)
  941. batch_index = batch_index.index_select(0, non_finished)
  942. batch_offset = batch_offset.index_select(0, non_finished)
  943. alive_seq = predictions.index_select(0, non_finished) \
  944. .view(-1, alive_seq.size(-1))
  945. # Reorder states.
  946. select_indices = batch_index.view(-1)
  947. src_features = src_features.index_select(0, select_indices)
  948. state.map_batch_fn(
  949. lambda state, dim: state.index_select(dim, select_indices))
  950. return results
  951. def calc_banned_tokens(self, prev_input_ids, num_hypos,
  952. no_repeat_ngram_size, cur_len):
  953. # Copied from fairseq for no_repeat_ngram in beam_search"""
  954. if cur_len + 1 < no_repeat_ngram_size:
  955. # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
  956. return [[] for _ in range(num_hypos)]
  957. generated_ngrams = [{} for _ in range(num_hypos)]
  958. for idx in range(num_hypos):
  959. gen_tokens = prev_input_ids[idx].cpu().numpy().tolist()
  960. generated_ngram = generated_ngrams[idx]
  961. for ngram in zip(
  962. *[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
  963. prev_ngram_tuple = tuple(ngram[:-1])
  964. generated_ngram[prev_ngram_tuple] = generated_ngram.get(
  965. prev_ngram_tuple, []) + [ngram[-1]]
  966. def _get_generated_ngrams(hypo_idx):
  967. # Before decoding the next token, prevent decoding of ngrams that have already appeared
  968. start_idx = cur_len + 1 - no_repeat_ngram_size
  969. ngram_idx = tuple(
  970. prev_input_ids[hypo_idx,
  971. start_idx:cur_len].cpu().numpy().tolist())
  972. return generated_ngrams[hypo_idx].get(ngram_idx, [])
  973. banned_tokens = [
  974. _get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)
  975. ]
  976. return banned_tokens
  977. def translate(self,
  978. input_ids: torch.Tensor,
  979. attention_mask: torch.Tensor = None,
  980. token_type_ids=None,
  981. *args,
  982. **kwargs) -> Dict[str, torch.Tensor]:
  983. if attention_mask is None:
  984. attention_mask = input_ids.ne(self.symbols['PAD']).long()
  985. batch = self.Batch(
  986. batch_size=input_ids.size()[0],
  987. src=input_ids,
  988. tgt=None,
  989. token_type_ids=token_type_ids,
  990. mask_src=attention_mask)
  991. translation_batch = self.translate_batch(batch, *args, **kwargs)
  992. preds = translation_batch['predictions']
  993. return {'predictions': preds}