| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import math
- import torch
- import torch.nn as nn
def gelu(x):
    """Apply the tanh approximation of the GELU activation element-wise.

    Evaluates ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as ``x``.
    """
    cubic_term = x + 0.044715 * torch.pow(x, 3)
    gate = torch.tanh(math.sqrt(2 / math.pi) * cubic_term)
    return 0.5 * x * (1 + gate)
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward sub-layer with pre-layer-norm and a residual.

    The input is layer-normalised, projected from ``d_model`` to ``d_ff``,
    passed through the ``gelu`` activation and dropout, projected back to
    ``d_model``, passed through a second dropout, and finally added to the
    un-normalised input.

    Args:
        d_model: Size of the last dimension of the input and output.
        d_ff: Size of the hidden projection.
        dropout: Dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        # Submodules are created in a fixed order so that seeded parameter
        # initialisation and the state_dict layout stay stable.
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.actv = gelu
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``x`` plus the feed-forward transform of its normalised form.

        Args:
            x: Tensor whose last dimension is ``d_model``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        normalised = self.layer_norm(x)
        hidden = self.actv(self.w_1(normalised))
        hidden = self.dropout_1(hidden)
        projected = self.w_2(hidden)
        # Residual connection uses the original (un-normalised) input.
        return self.dropout_2(projected) + x
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Keys, values and queries are each linearly projected, split into
    ``head_count`` heads of size ``model_dim // head_count``, attended with
    a softmax over the key axis, merged back and passed through a final
    linear projection.

    Args:
        head_count: Number of attention heads; must divide ``model_dim``.
        model_dim: Size of the last dimension of the inputs and the output.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, head_count, model_dim, dropout=0.1):
        assert model_dim % head_count == 0
        # Initialise nn.Module before assigning any attribute so attribute
        # registration never depends on nn.Module.__setattr__ tolerating an
        # uninitialised module.
        super(MultiHeadedAttention, self).__init__()
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim
        self.head_count = head_count
        # Creation order is kept (k, v, q, output) so seeded initialisation
        # and the state_dict layout are unchanged.
        self.linear_k = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_v = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_q = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(model_dim, model_dim)

    def forward(self, key, value, query, mask=None):
        """Compute attention of ``query`` over ``key`` / ``value``.

        Args:
            key: Tensor whose first dimension is the batch size and whose
                last dimension is ``model_dim``.
            value: Tensor laid out like ``key``.
            query: Tensor laid out like ``key``; its sequence length may
                differ from the key length.
            mask: Optional boolean tensor. After a dimension is inserted at
                axis 1 it must be expandable to the score shape
                ``(batch, head_count, query_len, key_len)``. Positions that
                are ``True`` are excluded from attention.

        Returns:
            Tensor of shape ``(batch, query_len, model_dim)``.
        """
        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count

        def shape(x):
            """Split the feature axis into heads: (batch, heads, len, dim)."""
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """Merge the heads back: (batch, len, heads * dim)."""
            return x.transpose(1, 2).contiguous() \
                .view(batch_size, -1, head_count * dim_per_head)

        key = shape(self.linear_k(key))
        value = shape(self.linear_v(value))
        # Scale the queries rather than the scores; the result is the same.
        query = shape(self.linear_q(query)) / math.sqrt(dim_per_head)

        scores = torch.matmul(query, key.transpose(2, 3))

        if mask is not None:
            mask = mask.unsqueeze(1).expand_as(scores)
            # -1e10 is not representable in float16 and makes masked_fill
            # fail there; clamp to the dtype's finite minimum. For float32
            # and wider the fill value remains exactly -1e10.
            fill_value = max(-1e10, torch.finfo(scores.dtype).min)
            scores = scores.masked_fill(mask, fill_value)

        attn = self.softmax(scores)
        drop_attn = self.dropout(attn)
        context = unshape(torch.matmul(drop_attn, value))
        return self.linear(context)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence.

    A table of shape ``(1, max_len, dim)`` is precomputed and stored as the
    non-trainable buffer ``pe``. Even feature columns hold
    ``sin(pos * w_k)`` and odd columns hold ``cos(pos * w_k)`` with
    ``w_k = exp(-2k * ln(10000) / dim)``.

    Args:
        dim: Size of the feature dimension. Both even and odd values are
            supported.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, dim, max_len=512):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float)
                              * -(math.log(10000.0) / dim)))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd ``dim`` there is one fewer odd column than even
        # columns, so drop the surplus frequency to keep the shapes aligned.
        # For an even ``dim`` the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :dim // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` with the encodings of its first positions added.

        Args:
            x: Tensor whose axis 1 is the sequence axis and whose last
                dimension is ``dim``. The sequence length is expected not to
                exceed ``max_len``; a longer input gets a table slice of only
                ``max_len`` rows, so the addition no longer lines up.

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward block.

    Args:
        d_model: Size of the feature dimension.
        heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward block.
        dropout: Dropout probability used by the attention, the feed-forward
            block and the residual dropout of this layer.
    """

    def __init__(self, d_model, heads, d_ff, dropout):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, iter, query, inputs, mask):
        """Run self-attention and the feed-forward block over ``inputs``.

        Args:
            iter: Index of this layer in the stack. Layer 0 feeds ``inputs``
                to the attention directly; every other layer applies layer
                normalisation first.
            query: Accepted for interface compatibility; not read here.
            inputs: Tensor to encode.
            mask: Boolean tensor; a dimension is inserted at axis 1 before
                it is handed to the attention module.

        Returns:
            Output of the feed-forward block applied to the attention result
            plus the residual ``inputs``.
        """
        attn_input = inputs if iter == 0 else self.layer_norm(inputs)
        attn_mask = mask.unsqueeze(1)
        attended = self.self_attn(
            attn_input, attn_input, attn_input, mask=attn_mask)
        # Residual connection around the attention sub-layer.
        residual = self.dropout(attended) + inputs
        return self.feed_forward(residual)
class TransformerEncoder(nn.Module):
    """Stack of ``TransformerEncoderLayer`` modules with positional encoding.

    Args:
        d_model: Size of the feature dimension.
        d_ff: Hidden size of each layer's feed-forward block.
        heads: Number of attention heads per layer.
        layers: Number of encoder layers.
        dropout: Dropout probability used after the positional encoding and
            inside every layer.
    """

    def __init__(self, d_model, d_ff, heads, layers, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.d_model = d_model
        self.layers = layers
        self.pos_emb = PositionalEncoding(d_model)
        stack = [
            TransformerEncoderLayer(d_model, heads, d_ff, dropout)
            for _ in range(layers)
        ]
        self.transformer_inter = nn.ModuleList(stack)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode ``x`` through every layer of the stack.

        Args:
            x: Input tensor; positional encodings are added and dropout is
                applied before the first layer.
            mask: Tensor whose zero entries mark positions to be excluded
                from attention (``mask.eq(0)`` is passed to each layer).

        Returns:
            Output of the last encoder layer.
        """
        hidden = self.dropout(self.pos_emb(x))
        for index in range(self.layers):
            layer = self.transformer_inter[index]
            hidden = layer(index, hidden, hidden, mask.eq(0))
        return hidden
|