pooling_layers.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. """ This implementation is adapted from https://github.com/wenet-e2e/wespeaker.
  3. """
  4. import torch
  5. import torch.nn as nn
  6. class TAP(nn.Module):
  7. """
  8. Temporal average pooling, only first-order mean is considered
  9. """
  10. def __init__(self, **kwargs):
  11. super(TAP, self).__init__()
  12. def forward(self, x):
  13. pooling_mean = x.mean(dim=-1)
  14. # To be compatible with 2D input
  15. pooling_mean = pooling_mean.flatten(start_dim=1)
  16. return pooling_mean
  17. class TSDP(nn.Module):
  18. """
  19. Temporal standard deviation pooling, only second-order std is considered
  20. """
  21. def __init__(self, **kwargs):
  22. super(TSDP, self).__init__()
  23. def forward(self, x):
  24. # The last dimension is the temporal axis
  25. pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
  26. pooling_std = pooling_std.flatten(start_dim=1)
  27. return pooling_std
  28. class TSTP(nn.Module):
  29. """
  30. Temporal statistics pooling, concatenate mean and std, which is used in
  31. x-vector
  32. Comment: simple concatenation can not make full use of both statistics
  33. """
  34. def __init__(self, **kwargs):
  35. super(TSTP, self).__init__()
  36. def forward(self, x):
  37. # The last dimension is the temporal axis
  38. pooling_mean = x.mean(dim=-1)
  39. pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
  40. pooling_mean = pooling_mean.flatten(start_dim=1)
  41. pooling_std = pooling_std.flatten(start_dim=1)
  42. stats = torch.cat((pooling_mean, pooling_std), 1)
  43. return stats
  44. class ASTP(nn.Module):
  45. """ Attentive statistics pooling: Channel- and context-dependent
  46. statistics pooling, first used in ECAPA_TDNN.
  47. """
  48. def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
  49. super(ASTP, self).__init__()
  50. self.global_context_att = global_context_att
  51. # Use Conv1d with stride == 1 rather than Linear, then we don't
  52. # need to transpose inputs.
  53. if global_context_att:
  54. self.linear1 = nn.Conv1d(
  55. in_dim * 3, bottleneck_dim,
  56. kernel_size=1) # equals W and b in the paper
  57. else:
  58. self.linear1 = nn.Conv1d(
  59. in_dim, bottleneck_dim,
  60. kernel_size=1) # equals W and b in the paper
  61. self.linear2 = nn.Conv1d(
  62. bottleneck_dim, in_dim,
  63. kernel_size=1) # equals V and k in the paper
  64. def forward(self, x):
  65. """
  66. x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
  67. or a 4-dimensional tensor in resnet architecture (B,C,F,T)
  68. 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
  69. """
  70. if len(x.shape) == 4:
  71. x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
  72. assert len(x.shape) == 3
  73. if self.global_context_att:
  74. context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
  75. context_std = torch.sqrt(
  76. torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
  77. x_in = torch.cat((x, context_mean, context_std), dim=1)
  78. else:
  79. x_in = x
  80. # DON'T use ReLU here! ReLU may be hard to converge.
  81. alpha = torch.tanh(
  82. self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
  83. alpha = torch.softmax(self.linear2(alpha), dim=2)
  84. mean = torch.sum(alpha * x, dim=2)
  85. var = torch.sum(alpha * (x**2), dim=2) - mean**2
  86. std = torch.sqrt(var.clamp(min=1e-10))
  87. return torch.cat([mean, std], dim=1)