self_attention.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import math
  18. import paddle
  19. from paddle import ParamAttr, nn
  20. from paddle import nn, ParamAttr
  21. from paddle.nn import functional as F
  22. import numpy as np
  23. gradient_clip = 10
  24. class WrapEncoderForFeature(nn.Layer):
  25. def __init__(
  26. self,
  27. src_vocab_size,
  28. max_length,
  29. n_layer,
  30. n_head,
  31. d_key,
  32. d_value,
  33. d_model,
  34. d_inner_hid,
  35. prepostprocess_dropout,
  36. attention_dropout,
  37. relu_dropout,
  38. preprocess_cmd,
  39. postprocess_cmd,
  40. weight_sharing,
  41. bos_idx=0,
  42. ):
  43. super(WrapEncoderForFeature, self).__init__()
  44. self.prepare_encoder = PrepareEncoder(
  45. src_vocab_size,
  46. d_model,
  47. max_length,
  48. prepostprocess_dropout,
  49. bos_idx=bos_idx,
  50. word_emb_param_name="src_word_emb_table",
  51. )
  52. self.encoder = Encoder(
  53. n_layer,
  54. n_head,
  55. d_key,
  56. d_value,
  57. d_model,
  58. d_inner_hid,
  59. prepostprocess_dropout,
  60. attention_dropout,
  61. relu_dropout,
  62. preprocess_cmd,
  63. postprocess_cmd,
  64. )
  65. def forward(self, enc_inputs):
  66. conv_features, src_pos, src_slf_attn_bias = enc_inputs
  67. enc_input = self.prepare_encoder(conv_features, src_pos)
  68. enc_output = self.encoder(enc_input, src_slf_attn_bias)
  69. return enc_output
  70. class WrapEncoder(nn.Layer):
  71. """
  72. embedder + encoder
  73. """
  74. def __init__(
  75. self,
  76. src_vocab_size,
  77. max_length,
  78. n_layer,
  79. n_head,
  80. d_key,
  81. d_value,
  82. d_model,
  83. d_inner_hid,
  84. prepostprocess_dropout,
  85. attention_dropout,
  86. relu_dropout,
  87. preprocess_cmd,
  88. postprocess_cmd,
  89. weight_sharing,
  90. bos_idx=0,
  91. ):
  92. super(WrapEncoder, self).__init__()
  93. self.prepare_decoder = PrepareDecoder(
  94. src_vocab_size, d_model, max_length, prepostprocess_dropout, bos_idx=bos_idx
  95. )
  96. self.encoder = Encoder(
  97. n_layer,
  98. n_head,
  99. d_key,
  100. d_value,
  101. d_model,
  102. d_inner_hid,
  103. prepostprocess_dropout,
  104. attention_dropout,
  105. relu_dropout,
  106. preprocess_cmd,
  107. postprocess_cmd,
  108. )
  109. def forward(self, enc_inputs):
  110. src_word, src_pos, src_slf_attn_bias = enc_inputs
  111. enc_input = self.prepare_decoder(src_word, src_pos)
  112. enc_output = self.encoder(enc_input, src_slf_attn_bias)
  113. return enc_output
  114. class Encoder(nn.Layer):
  115. """
  116. encoder
  117. """
  118. def __init__(
  119. self,
  120. n_layer,
  121. n_head,
  122. d_key,
  123. d_value,
  124. d_model,
  125. d_inner_hid,
  126. prepostprocess_dropout,
  127. attention_dropout,
  128. relu_dropout,
  129. preprocess_cmd="n",
  130. postprocess_cmd="da",
  131. ):
  132. super(Encoder, self).__init__()
  133. self.encoder_layers = list()
  134. for i in range(n_layer):
  135. self.encoder_layers.append(
  136. self.add_sublayer(
  137. "layer_%d" % i,
  138. EncoderLayer(
  139. n_head,
  140. d_key,
  141. d_value,
  142. d_model,
  143. d_inner_hid,
  144. prepostprocess_dropout,
  145. attention_dropout,
  146. relu_dropout,
  147. preprocess_cmd,
  148. postprocess_cmd,
  149. ),
  150. )
  151. )
  152. self.processor = PrePostProcessLayer(
  153. preprocess_cmd, d_model, prepostprocess_dropout
  154. )
  155. def forward(self, enc_input, attn_bias):
  156. for encoder_layer in self.encoder_layers:
  157. enc_output = encoder_layer(enc_input, attn_bias)
  158. enc_input = enc_output
  159. enc_output = self.processor(enc_output)
  160. return enc_output
  161. class EncoderLayer(nn.Layer):
  162. """
  163. EncoderLayer
  164. """
  165. def __init__(
  166. self,
  167. n_head,
  168. d_key,
  169. d_value,
  170. d_model,
  171. d_inner_hid,
  172. prepostprocess_dropout,
  173. attention_dropout,
  174. relu_dropout,
  175. preprocess_cmd="n",
  176. postprocess_cmd="da",
  177. ):
  178. super(EncoderLayer, self).__init__()
  179. self.preprocesser1 = PrePostProcessLayer(
  180. preprocess_cmd, d_model, prepostprocess_dropout
  181. )
  182. self.self_attn = MultiHeadAttention(
  183. d_key, d_value, d_model, n_head, attention_dropout
  184. )
  185. self.postprocesser1 = PrePostProcessLayer(
  186. postprocess_cmd, d_model, prepostprocess_dropout
  187. )
  188. self.preprocesser2 = PrePostProcessLayer(
  189. preprocess_cmd, d_model, prepostprocess_dropout
  190. )
  191. self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
  192. self.postprocesser2 = PrePostProcessLayer(
  193. postprocess_cmd, d_model, prepostprocess_dropout
  194. )
  195. def forward(self, enc_input, attn_bias):
  196. attn_output = self.self_attn(
  197. self.preprocesser1(enc_input), None, None, attn_bias
  198. )
  199. attn_output = self.postprocesser1(attn_output, enc_input)
  200. ffn_output = self.ffn(self.preprocesser2(attn_output))
  201. ffn_output = self.postprocesser2(ffn_output, attn_output)
  202. return ffn_output
  203. class MultiHeadAttention(nn.Layer):
  204. """
  205. Multi-Head Attention
  206. """
  207. def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.0):
  208. super(MultiHeadAttention, self).__init__()
  209. self.n_head = n_head
  210. self.d_key = d_key
  211. self.d_value = d_value
  212. self.d_model = d_model
  213. self.dropout_rate = dropout_rate
  214. self.q_fc = paddle.nn.Linear(
  215. in_features=d_model, out_features=d_key * n_head, bias_attr=False
  216. )
  217. self.k_fc = paddle.nn.Linear(
  218. in_features=d_model, out_features=d_key * n_head, bias_attr=False
  219. )
  220. self.v_fc = paddle.nn.Linear(
  221. in_features=d_model, out_features=d_value * n_head, bias_attr=False
  222. )
  223. self.proj_fc = paddle.nn.Linear(
  224. in_features=d_value * n_head, out_features=d_model, bias_attr=False
  225. )
  226. def _prepare_qkv(self, queries, keys, values, cache=None):
  227. if keys is None: # self-attention
  228. keys, values = queries, queries
  229. static_kv = False
  230. else: # cross-attention
  231. static_kv = True
  232. q = self.q_fc(queries)
  233. q = paddle.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
  234. q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
  235. if cache is not None and static_kv and "static_k" in cache:
  236. # for encoder-decoder attention in inference and has cached
  237. k = cache["static_k"]
  238. v = cache["static_v"]
  239. else:
  240. k = self.k_fc(keys)
  241. v = self.v_fc(values)
  242. k = paddle.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
  243. k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
  244. v = paddle.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
  245. v = paddle.transpose(x=v, perm=[0, 2, 1, 3])
  246. if cache is not None:
  247. if static_kv and not "static_k" in cache:
  248. # for encoder-decoder attention in inference and has not cached
  249. cache["static_k"], cache["static_v"] = k, v
  250. elif not static_kv:
  251. # for decoder self-attention in inference
  252. cache_k, cache_v = cache["k"], cache["v"]
  253. k = paddle.concat([cache_k, k], axis=2)
  254. v = paddle.concat([cache_v, v], axis=2)
  255. cache["k"], cache["v"] = k, v
  256. return q, k, v
  257. def forward(self, queries, keys, values, attn_bias, cache=None):
  258. # compute q ,k ,v
  259. keys = queries if keys is None else keys
  260. values = keys if values is None else values
  261. q, k, v = self._prepare_qkv(queries, keys, values, cache)
  262. # scale dot product attention
  263. product = paddle.matmul(x=q, y=k, transpose_y=True)
  264. product = product * self.d_model**-0.5
  265. if attn_bias is not None:
  266. product += attn_bias.astype(product.dtype)
  267. weights = F.softmax(product)
  268. if self.dropout_rate:
  269. weights = F.dropout(weights, p=self.dropout_rate, mode="downscale_in_infer")
  270. out = paddle.matmul(weights, v)
  271. # combine heads
  272. out = paddle.transpose(out, perm=[0, 2, 1, 3])
  273. out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
  274. # project to output
  275. out = self.proj_fc(out)
  276. return out
  277. class PrePostProcessLayer(nn.Layer):
  278. """
  279. PrePostProcessLayer
  280. """
  281. def __init__(self, process_cmd, d_model, dropout_rate):
  282. super(PrePostProcessLayer, self).__init__()
  283. self.process_cmd = process_cmd
  284. self.functors = []
  285. for cmd in self.process_cmd:
  286. if cmd == "a": # add residual connection
  287. self.functors.append(lambda x, y: x + y if y is not None else x)
  288. elif cmd == "n": # add layer normalization
  289. self.functors.append(
  290. self.add_sublayer(
  291. "layer_norm_%d" % len(self.sublayers()),
  292. paddle.nn.LayerNorm(
  293. normalized_shape=d_model,
  294. weight_attr=paddle.ParamAttr(
  295. initializer=paddle.nn.initializer.Constant(1.0)
  296. ),
  297. bias_attr=paddle.ParamAttr(
  298. initializer=paddle.nn.initializer.Constant(0.0)
  299. ),
  300. ),
  301. )
  302. )
  303. elif cmd == "d": # add dropout
  304. self.functors.append(
  305. lambda x: (
  306. F.dropout(x, p=dropout_rate, mode="downscale_in_infer")
  307. if dropout_rate
  308. else x
  309. )
  310. )
  311. def forward(self, x, residual=None):
  312. for i, cmd in enumerate(self.process_cmd):
  313. if cmd == "a":
  314. x = self.functors[i](x, residual)
  315. else:
  316. x = self.functors[i](x)
  317. return x
  318. class PrepareEncoder(nn.Layer):
  319. def __init__(
  320. self,
  321. src_vocab_size,
  322. src_emb_dim,
  323. src_max_len,
  324. dropout_rate=0,
  325. bos_idx=0,
  326. word_emb_param_name=None,
  327. pos_enc_param_name=None,
  328. ):
  329. super(PrepareEncoder, self).__init__()
  330. self.src_emb_dim = src_emb_dim
  331. self.src_max_len = src_max_len
  332. self.emb = paddle.nn.Embedding(
  333. num_embeddings=self.src_max_len, embedding_dim=self.src_emb_dim
  334. )
  335. self.dropout_rate = dropout_rate
  336. def forward(self, src_word, src_pos):
  337. src_word_emb = src_word
  338. src_word_emb = paddle.cast(src_word_emb, "float32")
  339. src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
  340. src_pos = paddle.squeeze(src_pos, axis=-1)
  341. src_pos_enc = self.emb(src_pos)
  342. src_pos_enc.stop_gradient = True
  343. enc_input = src_word_emb + src_pos_enc
  344. if self.dropout_rate:
  345. out = F.dropout(x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
  346. else:
  347. out = enc_input
  348. return out
  349. class PrepareDecoder(nn.Layer):
  350. def __init__(
  351. self,
  352. src_vocab_size,
  353. src_emb_dim,
  354. src_max_len,
  355. dropout_rate=0,
  356. bos_idx=0,
  357. word_emb_param_name=None,
  358. pos_enc_param_name=None,
  359. ):
  360. super(PrepareDecoder, self).__init__()
  361. self.src_emb_dim = src_emb_dim
  362. """
  363. self.emb0 = Embedding(num_embeddings=src_vocab_size,
  364. embedding_dim=src_emb_dim)
  365. """
  366. self.emb0 = paddle.nn.Embedding(
  367. num_embeddings=src_vocab_size,
  368. embedding_dim=self.src_emb_dim,
  369. padding_idx=bos_idx,
  370. weight_attr=paddle.ParamAttr(
  371. name=word_emb_param_name,
  372. initializer=nn.initializer.Normal(0.0, src_emb_dim**-0.5),
  373. ),
  374. )
  375. self.emb1 = paddle.nn.Embedding(
  376. num_embeddings=src_max_len,
  377. embedding_dim=self.src_emb_dim,
  378. weight_attr=paddle.ParamAttr(name=pos_enc_param_name),
  379. )
  380. self.dropout_rate = dropout_rate
  381. def forward(self, src_word, src_pos):
  382. src_word = paddle.cast(src_word, "int64")
  383. src_word = paddle.squeeze(src_word, axis=-1)
  384. src_word_emb = self.emb0(src_word)
  385. src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
  386. src_pos = paddle.squeeze(src_pos, axis=-1)
  387. src_pos_enc = self.emb1(src_pos)
  388. src_pos_enc.stop_gradient = True
  389. enc_input = src_word_emb + src_pos_enc
  390. if self.dropout_rate:
  391. out = F.dropout(x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
  392. else:
  393. out = enc_input
  394. return out
  395. class FFN(nn.Layer):
  396. """
  397. Feed-Forward Network
  398. """
  399. def __init__(self, d_inner_hid, d_model, dropout_rate):
  400. super(FFN, self).__init__()
  401. self.dropout_rate = dropout_rate
  402. self.fc1 = paddle.nn.Linear(in_features=d_model, out_features=d_inner_hid)
  403. self.fc2 = paddle.nn.Linear(in_features=d_inner_hid, out_features=d_model)
  404. def forward(self, x):
  405. hidden = self.fc1(x)
  406. hidden = F.relu(hidden)
  407. if self.dropout_rate:
  408. hidden = F.dropout(hidden, p=self.dropout_rate, mode="downscale_in_infer")
  409. out = self.fc2(hidden)
  410. return out