mossformer.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import copy
  3. import os
  4. from typing import Any, Dict
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from modelscope.metainfo import Models
  9. from modelscope.models import MODELS, TorchModel
  10. from modelscope.models.audio.separation.mossformer_block import (
  11. MossFormerModule, ScaledSinuEmbedding)
  12. from modelscope.models.audio.separation.mossformer_conv_module import (
  13. CumulativeLayerNorm, GlobalLayerNorm)
  14. from modelscope.models.base import Tensor
  15. from modelscope.utils.constant import Tasks
  16. EPS = 1e-8
  17. @MODELS.register_module(
  18. Tasks.speech_separation,
  19. module_name=Models.speech_mossformer_separation_temporal_8k)
  20. class MossFormer(TorchModel):
  21. """Library to support MossFormer speech separation.
  22. Args:
  23. model_dir (str): the model path.
  24. """
  25. def __init__(self, model_dir: str, *args, **kwargs):
  26. super().__init__(model_dir, *args, **kwargs)
  27. self.encoder = Encoder(
  28. kernel_size=kwargs['kernel_size'],
  29. out_channels=kwargs['out_channels'])
  30. self.decoder = Decoder(
  31. in_channels=kwargs['in_channels'],
  32. out_channels=1,
  33. kernel_size=kwargs['kernel_size'],
  34. stride=kwargs['stride'],
  35. bias=kwargs['bias'])
  36. self.mask_net = MossFormerMaskNet(
  37. kwargs['in_channels'],
  38. kwargs['out_channels'],
  39. MossFormerM(kwargs['num_blocks'], kwargs['d_model'],
  40. kwargs['attn_dropout'], kwargs['group_size'],
  41. kwargs['query_key_dim'], kwargs['expansion_factor'],
  42. kwargs['causal']),
  43. norm=kwargs['norm'],
  44. num_spks=kwargs['num_spks'])
  45. self.num_spks = kwargs['num_spks']
  46. def forward(self, inputs: Tensor) -> Dict[str, Any]:
  47. # Separation
  48. mix_w = self.encoder(inputs)
  49. est_mask = self.mask_net(mix_w)
  50. mix_w = torch.stack([mix_w] * self.num_spks)
  51. sep_h = mix_w * est_mask
  52. # Decoding
  53. est_source = torch.cat(
  54. [
  55. self.decoder(sep_h[i]).unsqueeze(-1)
  56. for i in range(self.num_spks)
  57. ],
  58. dim=-1,
  59. )
  60. # T changed after conv1d in encoder, fix it here
  61. t_origin = inputs.size(1)
  62. t_est = est_source.size(1)
  63. if t_origin > t_est:
  64. est_source = F.pad(est_source, (0, 0, 0, t_origin - t_est))
  65. else:
  66. est_source = est_source[:, :t_origin, :]
  67. return est_source
  68. def load_check_point(self, load_path=None, device=None):
  69. if not load_path:
  70. load_path = self.model_dir
  71. if not device:
  72. device = torch.device('cpu')
  73. self.encoder.load_state_dict(
  74. torch.load(
  75. os.path.join(load_path, 'encoder.bin'), map_location=device),
  76. strict=True)
  77. self.decoder.load_state_dict(
  78. torch.load(
  79. os.path.join(load_path, 'decoder.bin'), map_location=device),
  80. strict=True)
  81. self.mask_net.load_state_dict(
  82. torch.load(
  83. os.path.join(load_path, 'masknet.bin'), map_location=device),
  84. strict=True)
  85. def as_dict(self):
  86. return dict(
  87. encoder=self.encoder, decoder=self.decoder, masknet=self.mask_net)
  88. def select_norm(norm, dim, shape):
  89. """Just a wrapper to select the normalization type.
  90. """
  91. if norm == 'gln':
  92. return GlobalLayerNorm(dim, shape, elementwise_affine=True)
  93. if norm == 'cln':
  94. return CumulativeLayerNorm(dim, elementwise_affine=True)
  95. if norm == 'ln':
  96. return nn.GroupNorm(1, dim, eps=1e-8)
  97. else:
  98. return nn.BatchNorm1d(dim)
  99. class Encoder(nn.Module):
  100. """Convolutional Encoder Layer.
  101. Args:
  102. kernel_size: Length of filters.
  103. in_channels: Number of input channels.
  104. out_channels: Number of output channels.
  105. Examples:
  106. >>> x = torch.randn(2, 1000)
  107. >>> encoder = Encoder(kernel_size=4, out_channels=64)
  108. >>> h = encoder(x)
  109. >>> h.shape # torch.Size([2, 64, 499])
  110. """
  111. def __init__(self,
  112. kernel_size: int = 2,
  113. out_channels: int = 64,
  114. in_channels: int = 1):
  115. super(Encoder, self).__init__()
  116. self.conv1d = nn.Conv1d(
  117. in_channels=in_channels,
  118. out_channels=out_channels,
  119. kernel_size=kernel_size,
  120. stride=kernel_size // 2,
  121. groups=1,
  122. bias=False,
  123. )
  124. self.in_channels = in_channels
  125. def forward(self, x: torch.Tensor):
  126. """Return the encoded output.
  127. Args:
  128. x: Input tensor with dimensionality [B, L].
  129. Returns:
  130. Encoded tensor with dimensionality [B, N, T_out].
  131. where B = Batchsize
  132. L = Number of timepoints
  133. N = Number of filters
  134. T_out = Number of timepoints at the output of the encoder
  135. """
  136. # B x L -> B x 1 x L
  137. if self.in_channels == 1:
  138. x = torch.unsqueeze(x, dim=1)
  139. # B x 1 x L -> B x N x T_out
  140. x = self.conv1d(x)
  141. x = F.relu(x)
  142. return x
  143. class Decoder(nn.ConvTranspose1d):
  144. """A decoder layer that consists of ConvTranspose1d.
  145. Args:
  146. kernel_size: Length of filters.
  147. in_channels: Number of input channels.
  148. out_channels: Number of output channels.
  149. Example
  150. ---------
  151. >>> x = torch.randn(2, 100, 1000)
  152. >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
  153. >>> h = decoder(x)
  154. >>> h.shape
  155. torch.Size([2, 1003])
  156. """
  157. def __init__(self, *args, **kwargs):
  158. super(Decoder, self).__init__(*args, **kwargs)
  159. def forward(self, x):
  160. """Return the decoded output.
  161. Args:
  162. x: Input tensor with dimensionality [B, N, L].
  163. where, B = Batchsize,
  164. N = number of filters
  165. L = time points
  166. """
  167. if x.dim() not in [2, 3]:
  168. raise RuntimeError('{} accept 3/4D tensor as input'.format(
  169. self.__name__))
  170. x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
  171. if torch.squeeze(x).dim() == 1:
  172. x = torch.squeeze(x, dim=1)
  173. else:
  174. x = torch.squeeze(x)
  175. return x
  176. class IdentityBlock:
  177. """This block is used when we want to have identity transformation within the Dual_path block.
  178. Example
  179. -------
  180. >>> x = torch.randn(10, 100)
  181. >>> IB = IdentityBlock()
  182. >>> xhat = IB(x)
  183. """
  184. def _init__(self, **kwargs):
  185. pass
  186. def __call__(self, x):
  187. return x
  188. class MossFormerM(nn.Module):
  189. """This class implements the transformer encoder.
  190. Args:
  191. num_blocks : int
  192. Number of mossformer blocks to include.
  193. d_model : int
  194. The dimension of the input embedding.
  195. attn_dropout : float
  196. Dropout for the self-attention (Optional).
  197. group_size: int
  198. the chunk size
  199. query_key_dim: int
  200. the attention vector dimension
  201. expansion_factor: int
  202. the expansion factor for the linear projection in conv module
  203. causal: bool
  204. true for causal / false for non causal
  205. Example
  206. -------
  207. >>> import torch
  208. >>> x = torch.rand((8, 60, 512)) #B, S, N
  209. >>> net = MossFormerM(num_blocks=8, d_model=512)
  210. >>> output, _ = net(x)
  211. >>> output.shape
  212. torch.Size([8, 60, 512])
  213. """
  214. def __init__(self,
  215. num_blocks,
  216. d_model=None,
  217. attn_dropout=0.1,
  218. group_size=256,
  219. query_key_dim=128,
  220. expansion_factor=4.,
  221. causal=False):
  222. super().__init__()
  223. self.mossformerM = MossFormerModule(
  224. dim=d_model,
  225. depth=num_blocks,
  226. group_size=group_size,
  227. query_key_dim=query_key_dim,
  228. expansion_factor=expansion_factor,
  229. causal=causal,
  230. attn_dropout=attn_dropout)
  231. import speechbrain as sb
  232. self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
  233. def forward(self, src: torch.Tensor):
  234. """
  235. Args:
  236. src: Tensor shape [B, S, N],
  237. where, B = Batchsize,
  238. S = time points
  239. N = number of filters
  240. The sequence to the encoder layer (required).
  241. """
  242. output = self.mossformerM(src)
  243. output = self.norm(output)
  244. return output
  245. class ComputeAttention(nn.Module):
  246. """Computation block for dual-path processing.
  247. Args:
  248. att_mdl : torch.nn.module
  249. Model to process within the chunks.
  250. out_channels : int
  251. Dimensionality of attention model.
  252. norm : str
  253. Normalization type.
  254. skip_connection : bool
  255. Skip connection around the attention module.
  256. Example
  257. ---------
  258. >>> att_block = MossFormerM(num_blocks=8, d_model=512)
  259. >>> comp_att = ComputeAttention(att_block, 512)
  260. >>> x = torch.randn(10, 64, 512)
  261. >>> x = comp_att(x)
  262. >>> x.shape
  263. torch.Size([10, 64, 512])
  264. """
  265. def __init__(
  266. self,
  267. att_mdl,
  268. out_channels,
  269. norm='ln',
  270. skip_connection=True,
  271. ):
  272. super(ComputeAttention, self).__init__()
  273. self.att_mdl = att_mdl
  274. self.skip_connection = skip_connection
  275. # Norm
  276. self.norm = norm
  277. if norm is not None:
  278. self.att_norm = select_norm(norm, out_channels, 3)
  279. def forward(self, x: torch.Tensor):
  280. """Returns the output tensor.
  281. Args:
  282. x: Input tensor of dimension [B, S, N].
  283. Returns:
  284. out: Output tensor of dimension [B, S, N].
  285. where, B = Batchsize,
  286. N = number of filters
  287. S = time points
  288. """
  289. # [B, S, N]
  290. att_out = x.permute(0, 2, 1).contiguous()
  291. att_out = self.att_mdl(att_out)
  292. # [B, N, S]
  293. att_out = att_out.permute(0, 2, 1).contiguous()
  294. if self.norm is not None:
  295. att_out = self.att_norm(att_out)
  296. # [B, N, S]
  297. if self.skip_connection:
  298. att_out = att_out + x
  299. out = att_out
  300. return out
  301. class MossFormerMaskNet(nn.Module):
  302. """The dual path model which is the basis for dualpathrnn, sepformer, dptnet.
  303. Args:
  304. in_channels : int
  305. Number of channels at the output of the encoder.
  306. out_channels : int
  307. Number of channels that would be inputted to the intra and inter blocks.
  308. att_model : torch.nn.module
  309. Attention model to process the input sequence.
  310. norm : str
  311. Normalization type.
  312. num_spks : int
  313. Number of sources (speakers).
  314. skip_connection : bool
  315. Skip connection around attention module.
  316. use_global_pos_enc : bool
  317. Global positional encodings.
  318. Example
  319. ---------
  320. >>> mossformer_block = MossFormerM(num_blocks=8, d_model=512)
  321. >>> mossformer_masknet = MossFormerMaskNet(64, 64, att_model, num_spks=2)
  322. >>> x = torch.randn(10, 64, 2000)
  323. >>> x = mossformer_masknet(x)
  324. >>> x.shape
  325. torch.Size([2, 10, 64, 2000])
  326. """
  327. def __init__(
  328. self,
  329. in_channels,
  330. out_channels,
  331. att_model,
  332. norm='ln',
  333. num_spks=2,
  334. skip_connection=True,
  335. use_global_pos_enc=True,
  336. ):
  337. super(MossFormerMaskNet, self).__init__()
  338. self.num_spks = num_spks
  339. self.norm = select_norm(norm, in_channels, 3)
  340. self.conv1d_encoder = nn.Conv1d(
  341. in_channels, out_channels, 1, bias=False)
  342. self.use_global_pos_enc = use_global_pos_enc
  343. if self.use_global_pos_enc:
  344. self.pos_enc = ScaledSinuEmbedding(out_channels)
  345. self.mdl = copy.deepcopy(
  346. ComputeAttention(
  347. att_model,
  348. out_channels,
  349. norm,
  350. skip_connection=skip_connection,
  351. ))
  352. self.conv1d_out = nn.Conv1d(
  353. out_channels, out_channels * num_spks, kernel_size=1)
  354. self.conv1_decoder = nn.Conv1d(
  355. out_channels, in_channels, 1, bias=False)
  356. self.prelu = nn.PReLU()
  357. self.activation = nn.ReLU()
  358. # gated output layer
  359. self.output = nn.Sequential(
  360. nn.Conv1d(out_channels, out_channels, 1), nn.Tanh())
  361. self.output_gate = nn.Sequential(
  362. nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid())
  363. def forward(self, x: torch.Tensor):
  364. """Returns the output tensor.
  365. Args:
  366. x: Input tensor of dimension [B, N, S].
  367. Returns:
  368. out: Output tensor of dimension [spks, B, N, S]
  369. where, spks = Number of speakers
  370. B = Batchsize,
  371. N = number of filters
  372. S = the number of time frames
  373. """
  374. # before each line we indicate the shape after executing the line
  375. # [B, N, L]
  376. x = self.norm(x)
  377. # [B, N, L]
  378. x = self.conv1d_encoder(x)
  379. if self.use_global_pos_enc:
  380. base = x
  381. x = x.transpose(1, -1)
  382. emb = self.pos_enc(x)
  383. emb = emb.transpose(0, -1)
  384. x = base + emb
  385. # [B, N, S]
  386. x = self.mdl(x)
  387. x = self.prelu(x)
  388. # [B, N*spks, S]
  389. x = self.conv1d_out(x)
  390. b, _, s = x.shape
  391. # [B*spks, N, S]
  392. x = x.view(b * self.num_spks, -1, s)
  393. # [B*spks, N, S]
  394. x = self.output(x) * self.output_gate(x)
  395. # [B*spks, N, S]
  396. x = self.conv1_decoder(x)
  397. # [B, spks, N, S]
  398. _, n, L = x.shape
  399. x = x.view(b, self.num_spks, n, L)
  400. x = self.activation(x)
  401. # [spks, B, N, S]
  402. x = x.transpose(0, 1)
  403. return x