transformer.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import math
  3. import torch
  4. import torch.nn as nn
  5. def gelu(x):
  6. return 0.5 * x * (1 + torch.tanh(
  7. math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
  8. class PositionwiseFeedForward(nn.Module):
  9. def __init__(self, d_model, d_ff, dropout=0.1):
  10. super(PositionwiseFeedForward, self).__init__()
  11. self.w_1 = nn.Linear(d_model, d_ff)
  12. self.w_2 = nn.Linear(d_ff, d_model)
  13. self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
  14. self.actv = gelu
  15. self.dropout_1 = nn.Dropout(dropout)
  16. self.dropout_2 = nn.Dropout(dropout)
  17. def forward(self, x):
  18. inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x))))
  19. output = self.dropout_2(self.w_2(inter))
  20. return output + x
  21. class MultiHeadedAttention(nn.Module):
  22. def __init__(self, head_count, model_dim, dropout=0.1):
  23. assert model_dim % head_count == 0
  24. self.dim_per_head = model_dim // head_count
  25. self.model_dim = model_dim
  26. super(MultiHeadedAttention, self).__init__()
  27. self.head_count = head_count
  28. self.linear_k = nn.Linear(model_dim, head_count * self.dim_per_head)
  29. self.linear_v = nn.Linear(model_dim, head_count * self.dim_per_head)
  30. self.linear_q = nn.Linear(model_dim, head_count * self.dim_per_head)
  31. self.softmax = nn.Softmax(dim=-1)
  32. self.dropout = nn.Dropout(dropout)
  33. self.linear = nn.Linear(model_dim, model_dim)
  34. def forward(self, key, value, query, mask=None):
  35. batch_size = key.size(0)
  36. dim_per_head = self.dim_per_head
  37. head_count = self.head_count
  38. def shape(x):
  39. """ projection """
  40. return x.view(batch_size, -1, head_count, dim_per_head) \
  41. .transpose(1, 2)
  42. def unshape(x):
  43. """ compute context """
  44. return x.transpose(1, 2).contiguous() \
  45. .view(batch_size, -1, head_count * dim_per_head)
  46. key = self.linear_k(key).view(batch_size, -1, head_count,
  47. dim_per_head).transpose(1, 2)
  48. value = self.linear_v(value).view(batch_size, -1, head_count,
  49. dim_per_head).transpose(1, 2)
  50. query = self.linear_q(query).view(batch_size, -1, head_count,
  51. dim_per_head).transpose(1, 2)
  52. query = query / math.sqrt(dim_per_head)
  53. scores = torch.matmul(query, key.transpose(2, 3))
  54. if mask is not None:
  55. mask = mask.unsqueeze(1).expand_as(scores)
  56. scores = scores.masked_fill(mask, -1e10)
  57. attn = self.softmax(scores)
  58. drop_attn = self.dropout(attn)
  59. context = torch.matmul(drop_attn,
  60. value).transpose(1, 2).contiguous().view(
  61. batch_size, -1, head_count * dim_per_head)
  62. output = self.linear(context)
  63. return output
  64. class PositionalEncoding(nn.Module):
  65. def __init__(self, dim, max_len=512):
  66. super(PositionalEncoding, self).__init__()
  67. pe = torch.zeros(max_len, dim)
  68. position = torch.arange(0, max_len).unsqueeze(1)
  69. div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float)
  70. * -(math.log(10000.0) / dim)))
  71. pe[:, 0::2] = torch.sin(position.float() * div_term)
  72. pe[:, 1::2] = torch.cos(position.float() * div_term)
  73. pe = pe.unsqueeze(0)
  74. self.register_buffer('pe', pe)
  75. def forward(self, x):
  76. L = x.size(1)
  77. pos_emb = self.pe[:, :L]
  78. x = x + pos_emb
  79. return x
  80. class TransformerEncoderLayer(nn.Module):
  81. def __init__(self, d_model, heads, d_ff, dropout):
  82. super(TransformerEncoderLayer, self).__init__()
  83. self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
  84. self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
  85. self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
  86. self.dropout = nn.Dropout(dropout)
  87. def forward(self, iter, query, inputs, mask):
  88. if iter != 0:
  89. input_norm = self.layer_norm(inputs)
  90. else:
  91. input_norm = inputs
  92. mask = mask.unsqueeze(1)
  93. context = self.self_attn(input_norm, input_norm, input_norm, mask=mask)
  94. out = self.dropout(context) + inputs
  95. return self.feed_forward(out)
  96. class TransformerEncoder(nn.Module):
  97. def __init__(self, d_model, d_ff, heads, layers, dropout=0.1):
  98. super(TransformerEncoder, self).__init__()
  99. self.d_model = d_model
  100. self.layers = layers
  101. self.pos_emb = PositionalEncoding(d_model)
  102. self.transformer_inter = nn.ModuleList([
  103. TransformerEncoderLayer(d_model, heads, d_ff, dropout)
  104. for _ in range(layers)
  105. ])
  106. self.dropout = nn.Dropout(dropout)
  107. def forward(self, x, mask):
  108. x = self.pos_emb(x)
  109. x = self.dropout(x)
  110. for i in range(self.layers):
  111. x = self.transformer_inter[i](i, x, x, mask.eq(0))
  112. return x