fsmn.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from .model_def import (HEADER_BLOCK_SIZE, ActivationType, LayerType, f32ToI32,
  7. printNeonMatrix, printNeonVector)
# Module-level debug switch. No code in this section reads it; the layer
# classes carry their own per-instance ``debug`` attributes.
DEBUG = False
  9. def to_kaldi_matrix(np_mat):
  10. """ function that transform as str numpy mat to standard kaldi str matrix
  11. Args:
  12. np_mat: numpy mat
  13. Returns: str
  14. """
  15. np.set_printoptions(threshold=np.inf, linewidth=np.nan)
  16. out_str = str(np_mat)
  17. out_str = out_str.replace('[', '')
  18. out_str = out_str.replace(']', '')
  19. return '[ %s ]\n' % out_str
  20. def print_tensor(torch_tensor):
  21. """ print torch tensor for debug
  22. Args:
  23. torch_tensor: a tensor
  24. """
  25. re_str = ''
  26. x = torch_tensor.detach().squeeze().numpy()
  27. re_str += to_kaldi_matrix(x)
  28. re_str += '<!EndOfComponent>\n'
  29. print(re_str)
  30. class LinearTransform(nn.Module):
  31. def __init__(self, input_dim, output_dim):
  32. super(LinearTransform, self).__init__()
  33. self.input_dim = input_dim
  34. self.output_dim = output_dim
  35. self.linear = nn.Linear(input_dim, output_dim, bias=False)
  36. self.debug = False
  37. self.dataout = None
  38. def forward(self, input):
  39. output = self.linear(input)
  40. if self.debug:
  41. self.dataout = output
  42. return output
  43. def print_model(self):
  44. printNeonMatrix(self.linear.weight)
  45. def to_kaldi_nnet(self):
  46. re_str = ''
  47. re_str += '<LinearTransform> %d %d\n' % (self.output_dim,
  48. self.input_dim)
  49. re_str += '<LearnRateCoef> 1\n'
  50. linear_weights = self.state_dict()['linear.weight']
  51. x = linear_weights.squeeze().numpy()
  52. re_str += to_kaldi_matrix(x)
  53. re_str += '<!EndOfComponent>\n'
  54. return re_str
  55. class AffineTransform(nn.Module):
  56. def __init__(self, input_dim, output_dim):
  57. super(AffineTransform, self).__init__()
  58. self.input_dim = input_dim
  59. self.output_dim = output_dim
  60. self.linear = nn.Linear(input_dim, output_dim)
  61. self.debug = False
  62. self.dataout = None
  63. def forward(self, input):
  64. output = self.linear(input)
  65. if self.debug:
  66. self.dataout = output
  67. return output
  68. def print_model(self):
  69. printNeonMatrix(self.linear.weight)
  70. printNeonVector(self.linear.bias)
  71. def to_kaldi_nnet(self):
  72. re_str = ''
  73. re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
  74. self.input_dim)
  75. re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'
  76. linear_weights = self.state_dict()['linear.weight']
  77. x = linear_weights.squeeze().numpy()
  78. re_str += to_kaldi_matrix(x)
  79. linear_bias = self.state_dict()['linear.bias']
  80. x = linear_bias.squeeze().numpy()
  81. re_str += to_kaldi_matrix(x)
  82. re_str += '<!EndOfComponent>\n'
  83. return re_str
  84. class Fsmn(nn.Module):
  85. """
  86. FSMN implementation.
  87. """
  88. def __init__(self,
  89. input_dim,
  90. output_dim,
  91. lorder=None,
  92. rorder=None,
  93. lstride=None,
  94. rstride=None):
  95. super(Fsmn, self).__init__()
  96. self.dim = input_dim
  97. if lorder is None:
  98. return
  99. self.lorder = lorder
  100. self.rorder = rorder
  101. self.lstride = lstride
  102. self.rstride = rstride
  103. self.conv_left = nn.Conv2d(
  104. self.dim,
  105. self.dim, (lorder, 1),
  106. dilation=(lstride, 1),
  107. groups=self.dim,
  108. bias=False)
  109. if rorder > 0:
  110. self.conv_right = nn.Conv2d(
  111. self.dim,
  112. self.dim, (rorder, 1),
  113. dilation=(rstride, 1),
  114. groups=self.dim,
  115. bias=False)
  116. else:
  117. self.conv_right = None
  118. self.debug = False
  119. self.dataout = None
  120. def forward(self, input):
  121. x = torch.unsqueeze(input, 1)
  122. x_per = x.permute(0, 3, 2, 1)
  123. y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
  124. if self.conv_right is not None:
  125. y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
  126. y_right = y_right[:, :, self.rstride:, :]
  127. out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
  128. else:
  129. out = x_per + self.conv_left(y_left)
  130. out1 = out.permute(0, 3, 2, 1)
  131. output = out1.squeeze(1)
  132. if self.debug:
  133. self.dataout = output
  134. return output
  135. def print_model(self):
  136. tmpw = self.conv_left.weight
  137. tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
  138. for j in range(tmpw.shape[0]):
  139. tmpwm[:, j] = tmpw[j, 0, :, 0]
  140. printNeonMatrix(tmpwm)
  141. if self.conv_right is not None:
  142. tmpw = self.conv_right.weight
  143. tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
  144. for j in range(tmpw.shape[0]):
  145. tmpwm[:, j] = tmpw[j, 0, :, 0]
  146. printNeonMatrix(tmpwm)
  147. def to_kaldi_nnet(self):
  148. re_str = ''
  149. re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
  150. re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d <LStride> %d <RStride> %d <MaxNorm> 0\n' % (
  151. 1, self.lorder, self.rorder, self.lstride, self.rstride)
  152. lfiters = self.state_dict()['conv_left.weight']
  153. x = np.flipud(lfiters.squeeze().numpy().T)
  154. re_str += to_kaldi_matrix(x)
  155. if self.conv_right is not None:
  156. rfiters = self.state_dict()['conv_right.weight']
  157. x = (rfiters.squeeze().numpy().T)
  158. re_str += to_kaldi_matrix(x)
  159. re_str += '<!EndOfComponent>\n'
  160. return re_str
  161. class RectifiedLinear(nn.Module):
  162. def __init__(self, input_dim, output_dim):
  163. super(RectifiedLinear, self).__init__()
  164. self.dim = input_dim
  165. self.relu = nn.ReLU()
  166. def forward(self, input):
  167. return self.relu(input)
  168. def to_kaldi_nnet(self):
  169. re_str = ''
  170. re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
  171. re_str += '<!EndOfComponent>\n'
  172. return re_str
  173. class FSMNNet(nn.Module):
  174. """
  175. FSMN net for keyword spotting
  176. """
  177. def __init__(self,
  178. input_dim=200,
  179. linear_dim=128,
  180. proj_dim=128,
  181. lorder=10,
  182. rorder=1,
  183. num_syn=5,
  184. fsmn_layers=4):
  185. """
  186. Args:
  187. input_dim: input dimension
  188. linear_dim: fsmn input dimension
  189. proj_dim: fsmn projection dimension
  190. lorder: fsmn left order
  191. rorder: fsmn right order
  192. num_syn: output dimension
  193. fsmn_layers: no. of sequential fsmn layers
  194. """
  195. super(FSMNNet, self).__init__()
  196. self.input_dim = input_dim
  197. self.linear_dim = linear_dim
  198. self.proj_dim = proj_dim
  199. self.lorder = lorder
  200. self.rorder = rorder
  201. self.num_syn = num_syn
  202. self.fsmn_layers = fsmn_layers
  203. self.linear1 = AffineTransform(input_dim, linear_dim)
  204. self.relu = RectifiedLinear(linear_dim, linear_dim)
  205. self.fsmn = self._build_repeats(linear_dim, proj_dim, lorder, rorder,
  206. fsmn_layers)
  207. self.linear2 = AffineTransform(linear_dim, num_syn)
  208. @staticmethod
  209. def _build_repeats(linear_dim=136,
  210. proj_dim=68,
  211. lorder=3,
  212. rorder=2,
  213. fsmn_layers=5):
  214. repeats = [
  215. nn.Sequential(
  216. LinearTransform(linear_dim, proj_dim),
  217. Fsmn(proj_dim, proj_dim, lorder, rorder, 1, 1),
  218. AffineTransform(proj_dim, linear_dim),
  219. RectifiedLinear(linear_dim, linear_dim))
  220. for i in range(fsmn_layers)
  221. ]
  222. return nn.Sequential(*repeats)
  223. def forward(self, input):
  224. x1 = self.linear1(input)
  225. x2 = self.relu(x1)
  226. x3 = self.fsmn(x2)
  227. x4 = self.linear2(x3)
  228. return x4
  229. def print_model(self):
  230. self.linear1.print_model()
  231. for layer in self.fsmn:
  232. layer[0].print_model()
  233. layer[1].print_model()
  234. layer[2].print_model()
  235. self.linear2.print_model()
  236. def print_header(self):
  237. #
  238. # write total header
  239. #
  240. header = [0.0] * HEADER_BLOCK_SIZE * 4
  241. # numins
  242. header[0] = 0.0
  243. # numouts
  244. header[1] = 0.0
  245. # dimins
  246. header[2] = self.input_dim
  247. # dimouts
  248. header[3] = self.num_syn
  249. # numlayers
  250. header[4] = 3
  251. #
  252. # write each layer's header
  253. #
  254. hidx = 1
  255. header[HEADER_BLOCK_SIZE * hidx + 0] = float(
  256. LayerType.LAYER_DENSE.value)
  257. header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
  258. header[HEADER_BLOCK_SIZE * hidx + 2] = self.input_dim
  259. header[HEADER_BLOCK_SIZE * hidx + 3] = self.linear_dim
  260. header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0
  261. header[HEADER_BLOCK_SIZE * hidx + 5] = float(
  262. ActivationType.ACTIVATION_RELU.value)
  263. hidx += 1
  264. header[HEADER_BLOCK_SIZE * hidx + 0] = float(
  265. LayerType.LAYER_SEQUENTIAL_FSMN.value)
  266. header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
  267. header[HEADER_BLOCK_SIZE * hidx + 2] = self.linear_dim
  268. header[HEADER_BLOCK_SIZE * hidx + 3] = self.proj_dim
  269. header[HEADER_BLOCK_SIZE * hidx + 4] = self.lorder
  270. header[HEADER_BLOCK_SIZE * hidx + 5] = self.rorder
  271. header[HEADER_BLOCK_SIZE * hidx + 6] = self.fsmn_layers
  272. header[HEADER_BLOCK_SIZE * hidx + 7] = -1.0
  273. hidx += 1
  274. header[HEADER_BLOCK_SIZE * hidx + 0] = float(
  275. LayerType.LAYER_DENSE.value)
  276. header[HEADER_BLOCK_SIZE * hidx + 1] = 0.0
  277. header[HEADER_BLOCK_SIZE * hidx + 2] = self.linear_dim
  278. header[HEADER_BLOCK_SIZE * hidx + 3] = self.num_syn
  279. header[HEADER_BLOCK_SIZE * hidx + 4] = 1.0
  280. header[HEADER_BLOCK_SIZE * hidx + 5] = float(
  281. ActivationType.ACTIVATION_SOFTMAX.value)
  282. for h in header:
  283. print(f32ToI32(h))
  284. def to_kaldi_nnet(self):
  285. re_str = ''
  286. re_str += '<Nnet>\n'
  287. re_str += self.linear1.to_kaldi_nnet()
  288. re_str += self.relu.to_kaldi_nnet()
  289. for fsmn in self.fsmn:
  290. re_str += fsmn[0].to_kaldi_nnet()
  291. re_str += fsmn[1].to_kaldi_nnet()
  292. re_str += fsmn[2].to_kaldi_nnet()
  293. re_str += fsmn[3].to_kaldi_nnet()
  294. re_str += self.linear2.to_kaldi_nnet()
  295. re_str += '<Softmax> %d %d\n' % (self.num_syn, self.num_syn)
  296. re_str += '<!EndOfComponent>\n'
  297. re_str += '</Nnet>\n'
  298. return re_str
  299. class DFSMN(nn.Module):
  300. """
  301. One deep fsmn layer
  302. """
  303. def __init__(self,
  304. dimproj=64,
  305. dimlinear=128,
  306. lorder=20,
  307. rorder=1,
  308. lstride=1,
  309. rstride=1):
  310. """
  311. Args:
  312. dimproj: projection dimension, input and output dimension of memory blocks
  313. dimlinear: dimension of mapping layer
  314. lorder: left order
  315. rorder: right order
  316. lstride: left stride
  317. rstride: right stride
  318. """
  319. super(DFSMN, self).__init__()
  320. self.lorder = lorder
  321. self.rorder = rorder
  322. self.lstride = lstride
  323. self.rstride = rstride
  324. self.expand = AffineTransform(dimproj, dimlinear)
  325. self.shrink = LinearTransform(dimlinear, dimproj)
  326. self.conv_left = nn.Conv2d(
  327. dimproj,
  328. dimproj, (lorder, 1),
  329. dilation=(lstride, 1),
  330. groups=dimproj,
  331. bias=False)
  332. if rorder > 0:
  333. self.conv_right = nn.Conv2d(
  334. dimproj,
  335. dimproj, (rorder, 1),
  336. dilation=(rstride, 1),
  337. groups=dimproj,
  338. bias=False)
  339. else:
  340. self.conv_right = None
  341. def forward(self, input):
  342. f1 = F.relu(self.expand(input))
  343. p1 = self.shrink(f1)
  344. x = torch.unsqueeze(p1, 1)
  345. x_per = x.permute(0, 3, 2, 1)
  346. y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
  347. if self.conv_right is not None:
  348. y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
  349. y_right = y_right[:, :, self.rstride:, :]
  350. out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
  351. else:
  352. out = x_per + self.conv_left(y_left)
  353. out1 = out.permute(0, 3, 2, 1)
  354. output = input + out1.squeeze(1)
  355. return output
  356. def print_model(self):
  357. self.expand.print_model()
  358. self.shrink.print_model()
  359. tmpw = self.conv_left.weight
  360. tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
  361. for j in range(tmpw.shape[0]):
  362. tmpwm[:, j] = tmpw[j, 0, :, 0]
  363. printNeonMatrix(tmpwm)
  364. if self.conv_right is not None:
  365. tmpw = self.conv_right.weight
  366. tmpwm = torch.zeros(tmpw.shape[2], tmpw.shape[0])
  367. for j in range(tmpw.shape[0]):
  368. tmpwm[:, j] = tmpw[j, 0, :, 0]
  369. printNeonMatrix(tmpwm)
  370. def build_dfsmn_repeats(linear_dim=128,
  371. proj_dim=64,
  372. lorder=20,
  373. rorder=1,
  374. fsmn_layers=6):
  375. """
  376. build stacked dfsmn layers
  377. Args:
  378. linear_dim:
  379. proj_dim:
  380. lorder:
  381. rorder:
  382. fsmn_layers:
  383. Returns:
  384. """
  385. repeats = [
  386. nn.Sequential(DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
  387. for i in range(fsmn_layers)
  388. ]
  389. return nn.Sequential(*repeats)