activations.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import torch.nn as nn
  3. from .layer_base import LayerBase
  4. class RectifiedLinear(LayerBase):
  5. def __init__(self, input_dim, output_dim):
  6. super(RectifiedLinear, self).__init__()
  7. self.dim = input_dim
  8. self.relu = nn.ReLU()
  9. def forward(self, input):
  10. return self.relu(input)
  11. def to_kaldi_nnet(self):
  12. re_str = ''
  13. re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
  14. return re_str
  15. def load_kaldi_nnet(self, instr):
  16. return instr
  17. class LogSoftmax(LayerBase):
  18. def __init__(self, input_dim, output_dim):
  19. super(LogSoftmax, self).__init__()
  20. self.dim = input_dim
  21. self.ls = nn.LogSoftmax()
  22. def forward(self, input):
  23. return self.ls(input)
  24. def to_kaldi_nnet(self):
  25. re_str = ''
  26. re_str += '<Softmax> %d %d\n' % (self.dim, self.dim)
  27. return re_str
  28. def load_kaldi_nnet(self, instr):
  29. return instr
  30. class Sigmoid(LayerBase):
  31. def __init__(self, input_dim, output_dim):
  32. super(Sigmoid, self).__init__()
  33. self.dim = input_dim
  34. self.sig = nn.Sigmoid()
  35. def forward(self, input):
  36. return self.sig(input)
  37. def to_kaldi_nnet(self):
  38. re_str = ''
  39. re_str += '<Sigmoid> %d %d\n' % (self.dim, self.dim)
  40. return re_str
  41. def load_kaldi_nnet(self, instr):
  42. return instr