embedder.py 2.3 KB

# Copyright (c) Alibaba, Inc. and its affiliates.
  2. import torch
  3. import torch.nn as nn
  4. class Embedder(nn.Module):
  5. """
  6. Composite embedding layer.
  7. """
  8. def __init__(self,
  9. hidden_dim,
  10. num_token_embeddings,
  11. num_pos_embeddings,
  12. num_type_embeddings,
  13. num_turn_embeddings,
  14. padding_idx=None,
  15. dropout=0.1,
  16. pos_trainable=False):
  17. super(Embedder, self).__init__()
  18. self.token_embedding = nn.Embedding(num_token_embeddings, hidden_dim)
  19. self.pos_embedding = nn.Embedding(num_pos_embeddings, hidden_dim)
  20. self.pos_embedding.weight.requires_grad = pos_trainable
  21. self.type_embedding = nn.Embedding(num_type_embeddings, hidden_dim)
  22. self.turn_embedding = nn.Embedding(num_turn_embeddings, hidden_dim)
  23. self.dropout_layer = nn.Dropout(p=dropout)
  24. # follow the default xavier_uniform initializer in paddle version
  25. # otherwise, there are bugs for dec_probs computation in weight typing setting
  26. # default norm initializer in nn.Embedding in pytorch, which samples larger values
  27. nn.init.xavier_uniform_(self.token_embedding.weight)
  28. nn.init.xavier_uniform_(self.pos_embedding.weight)
  29. nn.init.xavier_uniform_(self.type_embedding.weight)
  30. nn.init.xavier_uniform_(self.turn_embedding.weight)
  31. return
  32. def forward(self, token_inp, pos_inp=None, type_inp=None, turn_inp=None):
  33. embed = self.token_embedding(token_inp)
  34. if pos_inp is not None:
  35. embed += self.pos_embedding(pos_inp)
  36. if type_inp is not None:
  37. embed += self.type_embedding(type_inp)
  38. if turn_inp is not None:
  39. embed += self.turn_embedding(turn_inp)
  40. embed = self.dropout_layer(embed)
  41. return embed
  42. def main():
  43. import numpy as np
  44. model = Embedder(10, 20, 20, 20, 20)
  45. token_inp = torch.tensor(
  46. np.random.randint(0, 19, [10, 10]).astype('int64'))
  47. pos_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64'))
  48. type_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64'))
  49. turn_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64'))
  50. out = model(token_inp, pos_inp, type_inp, turn_inp)
  51. print(out)
  52. if __name__ == '__main__':
  53. main()