fsmn.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import torch as th
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. class UniDeepFsmn(nn.Module):
  6. def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
  7. super(UniDeepFsmn, self).__init__()
  8. self.input_dim = input_dim
  9. self.output_dim = output_dim
  10. if lorder is None:
  11. return
  12. self.lorder = lorder
  13. self.hidden_size = hidden_size
  14. self.linear = nn.Linear(input_dim, hidden_size)
  15. self.project = nn.Linear(hidden_size, output_dim, bias=False)
  16. self.conv1 = nn.Conv2d(
  17. output_dim,
  18. output_dim, [lorder + lorder - 1, 1], [1, 1],
  19. groups=output_dim,
  20. bias=False)
  21. def forward(self, input):
  22. f1 = F.relu(self.linear(input))
  23. p1 = self.project(f1)
  24. x = th.unsqueeze(p1, 1)
  25. x_per = x.permute(0, 3, 2, 1)
  26. y = F.pad(x_per, [0, 0, self.lorder - 1, self.lorder - 1])
  27. out = x_per + self.conv1(y)
  28. out1 = out.permute(0, 3, 2, 1)
  29. return input + out1.squeeze()
  30. class UniDeepFsmnDual(nn.Module):
  31. def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
  32. super(UniDeepFsmnDual, self).__init__()
  33. self.input_dim = input_dim
  34. self.output_dim = output_dim
  35. if lorder is None:
  36. return
  37. self.lorder = lorder
  38. self.hidden_size = hidden_size
  39. self.linear = nn.Linear(input_dim, hidden_size)
  40. self.project = nn.Linear(hidden_size, output_dim, bias=False)
  41. self.conv1 = nn.Conv2d(
  42. output_dim,
  43. output_dim, [lorder + lorder - 1, 1], [1, 1],
  44. groups=output_dim,
  45. bias=False)
  46. self.conv2 = nn.Conv2d(
  47. output_dim,
  48. output_dim, [lorder + lorder - 1, 1], [1, 1],
  49. groups=output_dim // 4,
  50. bias=False)
  51. def forward(self, input):
  52. f1 = F.relu(self.linear(input))
  53. p1 = self.project(f1)
  54. x = th.unsqueeze(p1, 1)
  55. x_per = x.permute(0, 3, 2, 1)
  56. y = F.pad(x_per, [0, 0, self.lorder - 1, self.lorder - 1])
  57. conv1_out = x_per + self.conv1(y)
  58. z = F.pad(conv1_out, [0, 0, self.lorder - 1, self.lorder - 1])
  59. out = conv1_out + self.conv2(z)
  60. out1 = out.permute(0, 3, 2, 1)
  61. return input + out1.squeeze()
  62. class DilatedDenseNet(nn.Module):
  63. def __init__(self, depth=4, lorder=20, in_channels=64):
  64. super(DilatedDenseNet, self).__init__()
  65. self.depth = depth
  66. self.in_channels = in_channels
  67. self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)
  68. self.twidth = lorder * 2 - 1
  69. self.kernel_size = (self.twidth, 1)
  70. for i in range(self.depth):
  71. dil = 2**i
  72. pad_length = lorder + (dil - 1) * (lorder - 1) - 1
  73. setattr(self, 'pad{}'.format(i + 1),
  74. nn.ConstantPad2d((0, 0, pad_length, pad_length), value=0.))
  75. setattr(
  76. self, 'conv{}'.format(i + 1),
  77. nn.Conv2d(
  78. self.in_channels * (i + 1),
  79. self.in_channels,
  80. kernel_size=self.kernel_size,
  81. dilation=(dil, 1),
  82. groups=self.in_channels,
  83. bias=False))
  84. setattr(self, 'norm{}'.format(i + 1),
  85. nn.InstanceNorm2d(in_channels, affine=True))
  86. setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))
  87. def forward(self, x):
  88. skip = x
  89. for i in range(self.depth):
  90. out = getattr(self, 'pad{}'.format(i + 1))(skip)
  91. out = getattr(self, 'conv{}'.format(i + 1))(out)
  92. out = getattr(self, 'norm{}'.format(i + 1))(out)
  93. out = getattr(self, 'prelu{}'.format(i + 1))(out)
  94. skip = th.cat([out, skip], dim=1)
  95. return out
  96. class UniDeepFsmnDilated(nn.Module):
  97. def __init__(self,
  98. input_dim,
  99. output_dim,
  100. lorder=None,
  101. hidden_size=None,
  102. depth=2):
  103. super(UniDeepFsmnDilated, self).__init__()
  104. self.input_dim = input_dim
  105. self.output_dim = output_dim
  106. self.depth = depth
  107. if lorder is None:
  108. return
  109. self.lorder = lorder
  110. self.hidden_size = hidden_size
  111. self.linear = nn.Linear(input_dim, hidden_size)
  112. self.project = nn.Linear(hidden_size, output_dim, bias=False)
  113. self.conv = DilatedDenseNet(
  114. depth=self.depth, lorder=lorder, in_channels=output_dim)
  115. def forward(self, input):
  116. f1 = F.relu(self.linear(input))
  117. p1 = self.project(f1)
  118. x = th.unsqueeze(p1, 1)
  119. x_per = x.permute(0, 3, 2, 1)
  120. out = self.conv(x_per)
  121. out1 = out.permute(0, 3, 2, 1)
  122. return input + out1.squeeze()