voice.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
import pickle as pkl
import time
from collections import OrderedDict
from threading import Lock

import numpy as np
import torch
import yaml
from kantts.datasets.dataset import get_am_datasets, get_voc_datasets
from kantts.models import model_builder
from kantts.train.loss import criterion_builder
from kantts.train.trainer import GAN_Trainer, Sambert_Trainer, distributed_init
from kantts.utils.ling_unit.ling_unit import KanTtsLinguisticUnit
from torch.utils.data import DataLoader

from modelscope.utils.audio.audio_utils import TtsCustomParams
from modelscope.utils.audio.tts_exceptions import (
    TtsModelConfigurationException, TtsModelNotExistsException,
    TtsTrainingInvalidModelException)
from modelscope.utils.logger import get_logger
# Module-level logger used by the Voice class below for path, config and
# training-progress messages.
logger = get_logger()
  22. def count_parameters(model):
  23. return sum(p.numel() for p in model.parameters() if p.requires_grad)
  24. def denorm_f0(mel,
  25. f0_threshold=30,
  26. uv_threshold=0.6,
  27. norm_type='mean_std',
  28. f0_feature=None):
  29. if norm_type == 'mean_std':
  30. f0_mvn = f0_feature
  31. f0 = mel[:, -2]
  32. uv = mel[:, -1]
  33. uv[uv < uv_threshold] = 0.0
  34. uv[uv >= uv_threshold] = 1.0
  35. f0 = f0 * f0_mvn[1:, :] + f0_mvn[0:1, :]
  36. f0[f0 < f0_threshold] = f0_threshold
  37. mel[:, -2] = f0
  38. mel[:, -1] = uv
  39. else: # global
  40. f0_global_max_min = f0_feature
  41. f0 = mel[:, -2]
  42. uv = mel[:, -1]
  43. uv[uv < uv_threshold] = 0.0
  44. uv[uv >= uv_threshold] = 1.0
  45. f0 = f0 * (f0_global_max_min[0]
  46. - f0_global_max_min[1]) + f0_global_max_min[1]
  47. f0[f0 < f0_threshold] = f0_threshold
  48. mel[:, -2] = f0
  49. mel[:, -1] = uv
  50. return mel
  51. def binarize(mel, threshold=0.6):
  52. # vuv binarize
  53. res_mel = mel.clone()
  54. index = torch.where(mel[:, -1] < threshold)[0]
  55. res_mel[:, -1] = 1.0
  56. res_mel[:, -1][index] = 0.0
  57. return res_mel
  58. class Voice:
  59. def __init__(self,
  60. voice_name,
  61. voice_path=None,
  62. custom_ckpt={},
  63. ignore_mask=True,
  64. is_train=False):
  65. self.voice_name = voice_name
  66. self.voice_path = voice_path
  67. self.ignore_mask = ignore_mask
  68. self.is_train = is_train
  69. if not torch.cuda.is_available():
  70. self.device = torch.device('cpu')
  71. self.distributed = False
  72. else:
  73. torch.backends.cudnn.benchmark = True
  74. self.distributed, self.device, self.local_rank, self.world_size = distributed_init(
  75. )
  76. if len(custom_ckpt) != 0:
  77. self.am_config_path = custom_ckpt[TtsCustomParams.AM_CONFIG]
  78. self.voc_config_path = custom_ckpt[TtsCustomParams.VOC_CONFIG]
  79. if not os.path.isabs(self.am_config_path):
  80. self.am_config_path = os.path.join(voice_path,
  81. self.am_config_path)
  82. if not os.path.isabs(self.voc_config_path):
  83. self.voc_config_path = os.path.join(voice_path,
  84. self.voc_config_path)
  85. am_ckpt = custom_ckpt[TtsCustomParams.AM_CKPT]
  86. voc_ckpt = custom_ckpt[TtsCustomParams.VOC_CKPT]
  87. if not os.path.isabs(am_ckpt):
  88. am_ckpt = os.path.join(voice_path, am_ckpt)
  89. if not os.path.isabs(voc_ckpt):
  90. voc_ckpt = os.path.join(voice_path, voc_ckpt)
  91. self.am_ckpts = self.scan_ckpt(am_ckpt)
  92. self.voc_ckpts = self.scan_ckpt(voc_ckpt)
  93. self.se_path = custom_ckpt.get(TtsCustomParams.SE_FILE, 'se.npy')
  94. if not os.path.isabs(self.se_path):
  95. self.se_path = os.path.join(voice_path, self.se_path)
  96. self.se_model_path = custom_ckpt.get(TtsCustomParams.SE_MODEL,
  97. 'se.onnx')
  98. if not os.path.isabs(self.se_model_path):
  99. self.se_model_path = os.path.join(voice_path,
  100. self.se_model_path)
  101. self.audio_config = custom_ckpt.get(TtsCustomParams.AUIDO_CONFIG,
  102. 'audio_config.yaml')
  103. if not os.path.isabs(self.audio_config):
  104. self.audio_config = os.path.join(voice_path, self.audio_config)
  105. self.mvn_path = custom_ckpt.get(TtsCustomParams.MVN_FILE,
  106. 'mvn.npy')
  107. if not os.path.isabs(self.mvn_path):
  108. self.mvn_path = os.path.join(voice_path, self.mvn_path)
  109. else:
  110. self.audio_config = os.path.join(voice_path, 'audio_config.yaml')
  111. self.am_config_path = os.path.join(voice_path, 'am', 'config.yaml')
  112. self.voc_config_path = os.path.join(voice_path, 'voc',
  113. 'config.yaml')
  114. self.se_path = os.path.join(voice_path, 'am', 'se.npy')
  115. self.am_ckpts = self.scan_ckpt(
  116. os.path.join(voice_path, 'am', 'ckpt'))
  117. self.voc_ckpts = self.scan_ckpt(
  118. os.path.join(voice_path, 'voc', 'ckpt'))
  119. self.mvn_path = os.path.join(voice_path, 'am', 'mvn.npy')
  120. self.se_model_path = os.path.join(voice_path, 'se', 'ckpt',
  121. 'se.onnx')
  122. logger.info(
  123. f'am_config={self.am_config_path} voc_config={self.voc_config_path}'
  124. )
  125. logger.info(f'audio_config={self.audio_config}')
  126. logger.info(f'am_ckpts={self.am_ckpts}')
  127. logger.info(f'voc_ckpts={self.voc_ckpts}')
  128. logger.info(
  129. f'se_path={self.se_path} se_model_path={self.se_model_path}')
  130. logger.info(f'mvn_path={self.mvn_path}')
  131. if not os.path.exists(self.am_config_path):
  132. raise TtsModelConfigurationException(
  133. 'modelscope error: am configuration not found')
  134. if not os.path.exists(self.voc_config_path):
  135. raise TtsModelConfigurationException(
  136. 'modelscope error: voc configuration not found')
  137. if len(self.am_ckpts) == 0:
  138. raise TtsModelNotExistsException(
  139. 'modelscope error: am model file not found')
  140. if len(self.voc_ckpts) == 0:
  141. raise TtsModelNotExistsException(
  142. 'modelscope error: voc model file not found')
  143. with open(self.am_config_path, 'r') as f:
  144. self.am_config = yaml.load(f, Loader=yaml.Loader)
  145. with open(self.voc_config_path, 'r') as f:
  146. self.voc_config = yaml.load(f, Loader=yaml.Loader)
  147. if 'linguistic_unit' not in self.am_config:
  148. raise TtsModelConfigurationException(
  149. 'no linguistic_unit in am config')
  150. self.lang_type = self.am_config['linguistic_unit'].get(
  151. 'language', 'PinYin')
  152. self.model_loaded = False
  153. self.lock = Lock()
  154. self.ling_unit = KanTtsLinguisticUnit(self.am_config)
  155. self.ling_unit_size = self.ling_unit.get_unit_size()
  156. if self.ignore_mask:
  157. target_set = set(('sy', 'tone', 'syllable_flag', 'word_segment',
  158. 'emotion', 'speaker'))
  159. for k, v in self.ling_unit_size.items():
  160. if k in target_set:
  161. self.ling_unit_size[k] = v - 1
  162. self.am_config['Model']['KanTtsSAMBERT']['params'].update(
  163. self.ling_unit_size)
  164. self.se_enable = self.am_config['Model']['KanTtsSAMBERT'][
  165. 'params'].get('SE', False)
  166. if self.se_enable and not self.is_train:
  167. if not os.path.exists(self.se_path):
  168. raise TtsModelConfigurationException(
  169. f'se enabled but se_file:{self.se_path} not exists')
  170. self.se = np.load(self.se_path)
  171. self.nsf_enable = self.am_config['Model']['KanTtsSAMBERT'][
  172. 'params'].get('NSF', False)
  173. if self.nsf_enable and not self.is_train:
  174. self.nsf_norm_type = self.am_config['Model']['KanTtsSAMBERT'][
  175. 'params'].get('nsf_norm_type', 'mean_std')
  176. if self.nsf_norm_type == 'mean_std':
  177. if not os.path.exists(self.mvn_path):
  178. raise TtsModelNotExistsException(
  179. f'f0_mvn_file: {self.mvn_path} not exists')
  180. self.f0_feature = np.load(self.mvn_path)
  181. else: # global
  182. nsf_f0_global_minimum = self.am_config['Model'][
  183. 'KanTtsSAMBERT']['params'].get('nsf_f0_global_minimum',
  184. 30.0)
  185. nsf_f0_global_maximum = self.am_config['Model'][
  186. 'KanTtsSAMBERT']['params'].get('nsf_f0_global_maximum',
  187. 730.0)
  188. self.f0_feature = [
  189. nsf_f0_global_maximum, nsf_f0_global_minimum
  190. ]
  191. def scan_ckpt(self, ckpt_path):
  192. select_target = ckpt_path
  193. input_not_dir = False
  194. if not os.path.isdir(ckpt_path):
  195. input_not_dir = True
  196. ckpt_path = os.path.dirname(ckpt_path)
  197. filelist = os.listdir(ckpt_path)
  198. if len(filelist) == 0:
  199. return {}
  200. ckpts = {}
  201. for filename in filelist:
  202. # checkpoint_X.pth
  203. if len(filename) - 15 <= 0:
  204. continue
  205. if filename[-4:] == '.pth' and filename[0:10] == 'checkpoint':
  206. filename_prefix = filename.split('.')[0]
  207. idx = int(filename_prefix.split('_')[-1])
  208. path = os.path.join(ckpt_path, filename)
  209. if input_not_dir and path != select_target:
  210. continue
  211. ckpts[idx] = path
  212. od = OrderedDict(sorted(ckpts.items()))
  213. return od
  214. def load_am(self):
  215. self.am_model, _, _ = model_builder(self.am_config, self.device)
  216. self.am = self.am_model['KanTtsSAMBERT']
  217. state_dict = torch.load(
  218. self.am_ckpts[next(reversed(self.am_ckpts))],
  219. map_location=self.device)
  220. self.am.load_state_dict(state_dict['model'], strict=False)
  221. self.am.eval()
  222. def load_vocoder(self):
  223. from kantts.models.hifigan.hifigan import Generator
  224. self.voc_model = Generator(
  225. **self.voc_config['Model']['Generator']['params'])
  226. states = torch.load(
  227. self.voc_ckpts[next(reversed(self.voc_ckpts))],
  228. map_location=self.device)
  229. self.voc_model.load_state_dict(states['model']['generator'])
  230. if self.voc_config['Model']['Generator']['params']['out_channels'] > 1:
  231. from kantts.models.pqmf import PQMF
  232. self.voc_model = PQMF()
  233. self.voc_model.remove_weight_norm()
  234. self.voc_model.eval().to(self.device)
  235. def am_forward(self, symbol_seq):
  236. with self.lock:
  237. with torch.no_grad():
  238. inputs_feat_lst = self.ling_unit.encode_symbol_sequence(
  239. symbol_seq)
  240. inputs_feat_index = 0
  241. if self.ling_unit.using_byte():
  242. inputs_byte_index = (
  243. torch.from_numpy(
  244. inputs_feat_lst[inputs_feat_index]).long().to(
  245. self.device))
  246. inputs_ling = torch.stack([inputs_byte_index],
  247. dim=-1).unsqueeze(0)
  248. else:
  249. inputs_sy = (
  250. torch.from_numpy(
  251. inputs_feat_lst[inputs_feat_index]).long().to(
  252. self.device))
  253. inputs_feat_index = inputs_feat_index + 1
  254. inputs_tone = (
  255. torch.from_numpy(
  256. inputs_feat_lst[inputs_feat_index]).long().to(
  257. self.device))
  258. inputs_feat_index = inputs_feat_index + 1
  259. inputs_syllable = (
  260. torch.from_numpy(
  261. inputs_feat_lst[inputs_feat_index]).long().to(
  262. self.device))
  263. inputs_feat_index = inputs_feat_index + 1
  264. inputs_ws = (
  265. torch.from_numpy(
  266. inputs_feat_lst[inputs_feat_index]).long().to(
  267. self.device))
  268. inputs_ling = torch.stack(
  269. [inputs_sy, inputs_tone, inputs_syllable, inputs_ws],
  270. dim=-1).unsqueeze(0)
  271. inputs_feat_index = inputs_feat_index + 1
  272. inputs_emo = (
  273. torch.from_numpy(
  274. inputs_feat_lst[inputs_feat_index]).long().to(
  275. self.device).unsqueeze(0))
  276. inputs_feat_index = inputs_feat_index + 1
  277. if self.se_enable:
  278. inputs_spk = (
  279. torch.from_numpy(
  280. self.se.repeat(
  281. len(inputs_feat_lst[inputs_feat_index]),
  282. axis=0)).float().to(
  283. self.device).unsqueeze(0)[:, :-1, :])
  284. else:
  285. inputs_spk = (
  286. torch.from_numpy(
  287. inputs_feat_lst[inputs_feat_index]).long().to(
  288. self.device).unsqueeze(0)[:, :-1])
  289. inputs_len = (torch.zeros(1).to(self.device).long()
  290. + inputs_emo.size(1) - 1) # minus 1 for "~"
  291. res = self.am(inputs_ling[:, :-1, :], inputs_emo[:, :-1],
  292. inputs_spk, inputs_len)
  293. postnet_outputs = res['postnet_outputs']
  294. LR_length_rounded = res['LR_length_rounded']
  295. valid_length = int(LR_length_rounded[0].item())
  296. mel_post = postnet_outputs[0, :valid_length, :].cpu()
  297. if self.nsf_enable:
  298. mel_post = denorm_f0(
  299. mel_post,
  300. norm_type=self.nsf_norm_type,
  301. f0_feature=self.f0_feature)
  302. return mel_post
  303. def vocoder_forward(self, melspec):
  304. with torch.no_grad():
  305. x = melspec.to(self.device)
  306. if self.voc_model.nsf_enable:
  307. x = binarize(x)
  308. x = x.transpose(1, 0).unsqueeze(0)
  309. y = self.voc_model(x)
  310. if hasattr(self.voc_model, 'pqmf'):
  311. y = self.voc_model.synthesis(y)
  312. y = y.view(-1).cpu().numpy()
  313. return y
  314. def train_sambert(self,
  315. work_dir,
  316. stage_dir,
  317. data_dir,
  318. config_path,
  319. ignore_pretrain=False,
  320. hparams=dict()):
  321. logger.info('TRAIN SAMBERT....')
  322. if len(self.am_ckpts) == 0:
  323. raise TtsTrainingInvalidModelException(
  324. 'resume pretrain but model is empty')
  325. from_steps = hparams.get('resume_from_steps', -1)
  326. if from_steps < 0:
  327. from_latest = hparams.get('resume_from_latest', True)
  328. else:
  329. from_latest = hparams.get('resume_from_latest', False)
  330. train_steps = hparams.get('train_steps', 0)
  331. with open(self.audio_config, 'r') as f:
  332. config = yaml.load(f, Loader=yaml.Loader)
  333. with open(config_path, 'r') as f:
  334. config.update(yaml.load(f, Loader=yaml.Loader))
  335. config.update(hparams)
  336. resume_from = None
  337. if from_latest:
  338. from_steps = next(reversed(self.am_ckpts))
  339. resume_from = self.am_ckpts[from_steps]
  340. if not os.path.exists(resume_from):
  341. raise TtsTrainingInvalidModelException(
  342. f'latest model:{resume_from} not exists')
  343. else:
  344. if from_steps not in self.am_ckpts:
  345. raise TtsTrainingInvalidModelException(
  346. f'no such model from steps:{from_steps}')
  347. else:
  348. resume_from = self.am_ckpts[from_steps]
  349. if train_steps > 0:
  350. train_max_steps = train_steps + from_steps
  351. config['train_max_steps'] = train_max_steps
  352. logger.info(f'TRAINING steps: {train_max_steps}')
  353. config['create_time'] = time.strftime('%Y-%m-%d %H:%M:%S',
  354. time.localtime())
  355. from modelscope import __version__
  356. config['modelscope_version'] = __version__
  357. with open(os.path.join(stage_dir, 'config.yaml'), 'w') as f:
  358. yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None)
  359. for key, value in config.items():
  360. logger.info(f'{key} = {value}')
  361. if self.distributed:
  362. config['rank'] = torch.distributed.get_rank()
  363. config['distributed'] = True
  364. if self.se_enable:
  365. valid_enable = False
  366. valid_split_ratio = 0.00
  367. else:
  368. valid_enable = True
  369. valid_split_ratio = 0.02
  370. fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False)
  371. meta_file = [
  372. os.path.join(
  373. d,
  374. 'raw_metafile.txt' if not fp_enable else 'fprm_metafile.txt')
  375. for d in data_dir
  376. ]
  377. train_dataset, valid_dataset = get_am_datasets(
  378. meta_file,
  379. data_dir,
  380. config,
  381. config['allow_cache'],
  382. split_ratio=1.0 - valid_split_ratio)
  383. logger.info(f'The number of training files = {len(train_dataset)}.')
  384. logger.info(f'The number of validation files = {len(valid_dataset)}.')
  385. sampler = {'train': None, 'valid': None}
  386. if self.distributed:
  387. # setup sampler for distributed training
  388. from torch.utils.data.distributed import DistributedSampler
  389. sampler['train'] = DistributedSampler(
  390. dataset=train_dataset,
  391. num_replicas=self.world_size,
  392. shuffle=True,
  393. )
  394. sampler['valid'] = DistributedSampler(
  395. dataset=valid_dataset,
  396. num_replicas=self.world_size,
  397. shuffle=False,
  398. ) if valid_enable else None
  399. train_dataloader = DataLoader(
  400. train_dataset,
  401. shuffle=False if self.distributed else True,
  402. collate_fn=train_dataset.collate_fn,
  403. batch_size=config['batch_size'],
  404. num_workers=config['num_workers'],
  405. sampler=sampler['train'],
  406. pin_memory=config['pin_memory'],
  407. )
  408. valid_dataloader = DataLoader(
  409. valid_dataset,
  410. shuffle=False if self.distributed else True,
  411. collate_fn=valid_dataset.collate_fn,
  412. batch_size=config['batch_size'],
  413. num_workers=config['num_workers'],
  414. sampler=sampler['valid'],
  415. pin_memory=config['pin_memory'],
  416. ) if valid_enable else None
  417. ling_unit_size = train_dataset.ling_unit.get_unit_size()
  418. config['Model']['KanTtsSAMBERT']['params'].update(ling_unit_size)
  419. model, optimizer, scheduler = model_builder(config, self.device,
  420. self.local_rank,
  421. self.distributed)
  422. criterion = criterion_builder(config, self.device)
  423. trainer = Sambert_Trainer(
  424. config=config,
  425. model=model,
  426. optimizer=optimizer,
  427. scheduler=scheduler,
  428. criterion=criterion,
  429. device=self.device,
  430. sampler=sampler,
  431. train_loader=train_dataloader,
  432. valid_loader=valid_dataloader,
  433. max_steps=train_max_steps,
  434. save_dir=stage_dir,
  435. save_interval=config['save_interval_steps'],
  436. valid_interval=config['eval_interval_steps'],
  437. log_interval=config['log_interval'],
  438. grad_clip=config['grad_norm'],
  439. )
  440. if resume_from is not None:
  441. trainer.load_checkpoint(resume_from, True, True)
  442. logger.info(f'Successfully resumed from {resume_from}.')
  443. try:
  444. trainer.train()
  445. except (Exception, KeyboardInterrupt) as e:
  446. logger.error(e, exc_info=True)
  447. trainer.save_checkpoint(
  448. os.path.join(
  449. os.path.join(stage_dir, 'ckpt'),
  450. f'checkpoint-{trainer.steps}.pth'))
  451. logger.info(
  452. f'Successfully saved checkpoint @ {trainer.steps}steps.')
  453. def train_hifigan(self,
  454. work_dir,
  455. stage_dir,
  456. data_dir,
  457. config_path,
  458. ignore_pretrain=False,
  459. hparams=dict()):
  460. logger.info('TRAIN HIFIGAN....')
  461. if len(self.voc_ckpts) == 0:
  462. raise TtsTrainingInvalidModelException(
  463. 'resume pretrain but model is empty')
  464. from_steps = hparams.get('resume_from_steps', -1)
  465. if from_steps < 0:
  466. from_latest = hparams.get('resume_from_latest', True)
  467. else:
  468. from_latest = hparams.get('resume_from_latest', False)
  469. train_steps = hparams.get('train_steps', 0)
  470. with open(self.audio_config, 'r') as f:
  471. config = yaml.load(f, Loader=yaml.Loader)
  472. with open(config_path, 'r') as f:
  473. config.update(yaml.load(f, Loader=yaml.Loader))
  474. config.update(hparams)
  475. resume_from = None
  476. if from_latest:
  477. from_steps = next(reversed(self.voc_ckpts))
  478. resume_from = self.voc_ckpts[from_steps]
  479. if not os.path.exists(resume_from):
  480. raise TtsTrainingInvalidModelException(
  481. f'latest model:{resume_from} not exists')
  482. else:
  483. if from_steps not in self.voc_ckpts:
  484. raise TtsTrainingInvalidModelException(
  485. f'no such model from steps:{from_steps}')
  486. else:
  487. resume_from = self.voc_ckpts[from_steps]
  488. if train_steps > 0:
  489. train_max_steps = train_steps
  490. config['train_max_steps'] = train_max_steps
  491. logger.info(f'TRAINING steps: {train_max_steps}')
  492. logger.info(f'resume from: {resume_from}')
  493. config['create_time'] = time.strftime('%Y-%m-%d %H:%M:%S',
  494. time.localtime())
  495. from modelscope import __version__
  496. config['modelscope_version'] = __version__
  497. with open(os.path.join(stage_dir, 'config.yaml'), 'w') as f:
  498. yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None)
  499. for key, value in config.items():
  500. logger.info(f'{key} = {value}')
  501. train_dataset, valid_dataset = get_voc_datasets(config, data_dir)
  502. logger.info(f'The number of training files = {len(train_dataset)}.')
  503. logger.info(f'The number of validation files = {len(valid_dataset)}.')
  504. sampler = {'train': None, 'valid': None}
  505. if self.distributed:
  506. # setup sampler for distributed training
  507. from torch.utils.data.distributed import DistributedSampler
  508. sampler['train'] = DistributedSampler(
  509. dataset=train_dataset,
  510. num_replicas=self.world_size,
  511. shuffle=True,
  512. )
  513. sampler['valid'] = DistributedSampler(
  514. dataset=valid_dataset,
  515. num_replicas=self.world_size,
  516. shuffle=False,
  517. )
  518. train_dataloader = DataLoader(
  519. train_dataset,
  520. shuffle=False if self.distributed else True,
  521. collate_fn=train_dataset.collate_fn,
  522. batch_size=config['batch_size'],
  523. num_workers=config['num_workers'],
  524. sampler=sampler['train'],
  525. pin_memory=config['pin_memory'],
  526. )
  527. valid_dataloader = DataLoader(
  528. valid_dataset,
  529. shuffle=False if self.distributed else True,
  530. collate_fn=valid_dataset.collate_fn,
  531. batch_size=config['batch_size'],
  532. num_workers=config['num_workers'],
  533. sampler=sampler['valid'],
  534. pin_memory=config['pin_memory'],
  535. )
  536. model, optimizer, scheduler = model_builder(config, self.device,
  537. self.local_rank,
  538. self.distributed)
  539. criterion = criterion_builder(config, self.device)
  540. trainer = GAN_Trainer(
  541. config=config,
  542. model=model,
  543. optimizer=optimizer,
  544. scheduler=scheduler,
  545. criterion=criterion,
  546. device=self.device,
  547. sampler=sampler,
  548. train_loader=train_dataloader,
  549. valid_loader=valid_dataloader,
  550. max_steps=train_max_steps,
  551. save_dir=stage_dir,
  552. save_interval=config['save_interval_steps'],
  553. valid_interval=config['eval_interval_steps'],
  554. log_interval=config['log_interval_steps'],
  555. )
  556. if resume_from is not None:
  557. trainer.load_checkpoint(resume_from)
  558. logger.info(f'Successfully resumed from {resume_from}.')
  559. try:
  560. trainer.train()
  561. except (Exception, KeyboardInterrupt) as e:
  562. logger.error(e, exc_info=True)
  563. trainer.save_checkpoint(
  564. os.path.join(
  565. os.path.join(stage_dir, 'ckpt'),
  566. f'checkpoint-{trainer.steps}.pth'))
  567. logger.info(
  568. f'Successfully saved checkpoint @ {trainer.steps}steps.')
  569. def forward(self, symbol_seq):
  570. with self.lock:
  571. if not self.model_loaded:
  572. self.load_am()
  573. self.load_vocoder()
  574. self.model_loaded = True
  575. return self.vocoder_forward(self.am_forward(symbol_seq))