sambert_hifi.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import datetime
import json
import os
import shutil
import sys
import wave
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import yaml

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.audio.audio_utils import (TtsCustomParams, TtsTrainType,
                                                ndarray_pcm_to_wav)
from modelscope.utils.audio.tts_exceptions import (
    TtsFrontendInitializeFailedException,
    TtsFrontendLanguageTypeInvalidException, TtsModelConfigurationException,
    TtsModelNotExistsException, TtsTrainingCfgNotExistsException,
    TtsVoiceNotExistsException)
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

from .voice import Voice
# Public API of this module: only the model class is exported.
__all__ = ['SambertHifigan']

# Module-level logger; the class below uses it for its log messages.
logger = get_logger()
  28. @MODELS.register_module(
  29. Tasks.text_to_speech, module_name=Models.sambert_hifigan)
  30. class SambertHifigan(Model):
  31. def __init__(self, model_dir, *args, **kwargs):
  32. super().__init__(model_dir, *args, **kwargs)
  33. self.model_dir = model_dir
  34. self.sample_rate = kwargs.get('sample_rate', 16000)
  35. self.is_train = False
  36. if 'is_train' in kwargs:
  37. is_train = kwargs['is_train']
  38. if isinstance(is_train, bool):
  39. self.is_train = is_train
  40. # check legacy modelcard
  41. self.ignore_mask = False
  42. if 'am' in kwargs:
  43. if 'linguistic_unit' in kwargs['am']:
  44. self.ignore_mask = not kwargs['am']['linguistic_unit'].get(
  45. 'has_mask', True)
  46. self.voices, self.voice_cfg, self.lang_type = self.load_voice(
  47. model_dir, kwargs.get('custom_ckpt', {}))
  48. if len(self.voices) == 0 or len(self.voice_cfg.get('voices', [])) == 0:
  49. raise TtsVoiceNotExistsException('modelscope error: voices empty')
  50. if self.voice_cfg['voices']:
  51. self.default_voice_name = self.voice_cfg['voices'][0]
  52. else:
  53. raise TtsVoiceNotExistsException(
  54. 'modelscope error: voices is empty in voices.json')
  55. # initialize frontend
  56. if sys.version_info >= (3, 11):
  57. raise ImportError('Python version needs to be <= 3.10')
  58. import ttsfrd
  59. frontend = ttsfrd.TtsFrontendEngine()
  60. zip_file = os.path.join(model_dir, 'resource.zip')
  61. self.res_path = os.path.join(model_dir, 'resource')
  62. with zipfile.ZipFile(zip_file, 'r') as zip_ref:
  63. zip_ref.extractall(model_dir)
  64. if not frontend.initialize(self.res_path):
  65. raise TtsFrontendInitializeFailedException(
  66. 'modelscope error: resource invalid: {}'.format(self.res_path))
  67. if not frontend.set_lang_type(self.lang_type):
  68. raise TtsFrontendLanguageTypeInvalidException(
  69. 'modelscope error: language type invalid: {}'.format(
  70. self.lang_type))
  71. self.frontend = frontend
  72. def build_voice_from_custom(self, model_dir, custom_ckpt):
  73. necessary_files = (TtsCustomParams.VOICE_NAME, TtsCustomParams.AM_CKPT,
  74. TtsCustomParams.VOC_CKPT, TtsCustomParams.AM_CONFIG,
  75. TtsCustomParams.VOC_CONFIG)
  76. voices = {}
  77. voices_cfg = {}
  78. lang_type = 'PinYin'
  79. for k in necessary_files:
  80. if k not in custom_ckpt:
  81. raise TtsModelNotExistsException(
  82. f'custom ckpt must have: {necessary_files}')
  83. voice_name = custom_ckpt[TtsCustomParams.VOICE_NAME]
  84. voice = Voice(
  85. voice_name=voice_name,
  86. voice_path=model_dir,
  87. custom_ckpt=custom_ckpt,
  88. ignore_mask=self.ignore_mask,
  89. is_train=self.is_train)
  90. voices[voice_name] = voice
  91. voices_cfg['voices'] = [voice_name]
  92. lang_type = voice.lang_type
  93. return voices, voices_cfg, lang_type
  94. def load_voice(self, model_dir, custom_ckpt):
  95. voices = {}
  96. voices_path = os.path.join(model_dir, 'voices')
  97. voices_json_path = os.path.join(voices_path, 'voices.json')
  98. lang_type = 'PinYin'
  99. if len(custom_ckpt) != 0:
  100. return self.build_voice_from_custom(model_dir, custom_ckpt)
  101. if not os.path.exists(voices_path) or not os.path.exists(
  102. voices_json_path):
  103. return voices, {}, lang_type
  104. with open(voices_json_path, 'r', encoding='utf-8') as f:
  105. voice_cfg = json.load(f)
  106. if 'voices' not in voice_cfg:
  107. return voices, {}, lang_type
  108. for name in voice_cfg['voices']:
  109. voice_path = os.path.join(voices_path, name)
  110. if not os.path.exists(voice_path):
  111. continue
  112. voices[name] = Voice(
  113. name,
  114. voice_path,
  115. ignore_mask=self.ignore_mask,
  116. is_train=self.is_train)
  117. lang_type = voices[name].lang_type
  118. return voices, voice_cfg, lang_type
  119. def save_voices(self):
  120. voices_json_path = os.path.join(self.model_dir, 'voices',
  121. 'voices.json')
  122. if os.path.exists(voices_json_path):
  123. os.remove(voices_json_path)
  124. save_voices = {}
  125. save_voices['voices'] = []
  126. for k in self.voices.keys():
  127. save_voices['voices'].append(k)
  128. with open(voices_json_path, 'w', encoding='utf-8') as f:
  129. json.dump(save_voices, f)
  130. def get_voices(self):
  131. return self.voices, self.voice_cfg
  132. def create_empty_voice(self, voice_name, audio_config, am_config_path,
  133. voc_config_path):
  134. voice_name_path = os.path.join(self.model_dir, 'voices', voice_name)
  135. if os.path.exists(voice_name_path):
  136. shutil.rmtree(voice_name_path)
  137. os.makedirs(voice_name_path, exist_ok=True)
  138. if audio_config and os.path.exists(audio_config) and os.path.isfile(
  139. audio_config):
  140. shutil.copy(audio_config, voice_name_path)
  141. voice_am_path = os.path.join(voice_name_path, 'am')
  142. voice_voc_path = os.path.join(voice_name_path, 'voc')
  143. if am_config_path and os.path.exists(
  144. am_config_path) and os.path.isfile(am_config):
  145. am_config_name = os.path.join(voice_am_path, 'config.yaml')
  146. shutil.copy(am_config_path, am_config_name)
  147. if voc_config_path and os.path.exists(
  148. voc_config_path) and os.path.isfile(voc_config):
  149. voc_config_name = os.path.join(voice_am_path, 'config.yaml')
  150. shutil.copy(voc_config_path, voc_config_name)
  151. am_ckpt_path = os.path.join(voice_am_path, 'ckpt')
  152. voc_ckpt_path = os.path.join(voice_voc_path, 'ckpt')
  153. os.makedirs(am_ckpt_path, exist_ok=True)
  154. os.makedirs(voc_ckpt_path, exist_ok=True)
  155. self.voices[voice_name] = Voice(
  156. voice_name=voice_name,
  157. voice_path=voice_name_path,
  158. allow_empty=True)
  159. def get_voice_audio_config_path(self, voice):
  160. if voice not in self.voices:
  161. return ''
  162. return self.voices[voice].audio_config
  163. def get_voice_se_model_path(self, voice):
  164. if voice not in self.voices:
  165. return ''
  166. if self.voices[voice].se_enable:
  167. return self.voices[voice].se_model_path
  168. else:
  169. return ''
  170. def get_voice_lang_path(self, voice):
  171. if voice not in self.voices:
  172. return ''
  173. return self.voices[voice].lang_dir
  174. def synthesis_one_sentences(self, voice_name, text):
  175. if voice_name not in self.voices:
  176. raise TtsVoiceNotExistsException(
  177. f'modelscope error: Voice {voice_name} not exists')
  178. return self.voices[voice_name].forward(text)
  179. def train(self,
  180. voice,
  181. dirs,
  182. train_type,
  183. configs_path_dict=None,
  184. ignore_pretrain=False,
  185. create_if_not_exists=False,
  186. hparam=None):
  187. plt.set_loglevel('info')
  188. work_dir = dirs['work_dir']
  189. am_dir = dirs['am_tmp_dir']
  190. voc_dir = dirs['voc_tmp_dir']
  191. data_dir = dirs['data_dir']
  192. target_voice = None
  193. if voice not in self.voices:
  194. if not create_if_not_exists:
  195. raise TtsVoiceNotExistsException(
  196. f'modelscope error: Voice {voice_name} not exists')
  197. am_config_path = configs_path_dict.get('am_config',
  198. 'am_config.yaml')
  199. voc_config_path = configs_path_dict.get('voc_config',
  200. 'voc_config.yaml')
  201. if TtsTrainType.TRAIN_TYPE_SAMBERT in train_type and not am_config:
  202. raise TtsTrainingCfgNotExistsException(
  203. 'training new voice am with empty am_config')
  204. if TtsTrainType.TRAIN_TYPE_VOC in train_type and not voc_config:
  205. raise TtsTrainingCfgNotExistsException(
  206. 'training new voice voc with empty voc_config')
  207. else:
  208. target_voice = self.voices[voice]
  209. am_config_path = target_voice.am_config_path
  210. voc_config_path = target_voice.voc_config_path
  211. if configs_path_dict:
  212. if 'am_config' in configs_path_dict:
  213. am_override = configs_path_dict['am_config']
  214. if os.path.exists(am_override):
  215. am_config_path = am_override
  216. if 'voc_config' in configs_path_dict:
  217. voc_override = configs_path_dict['voc_config']
  218. if os.path.exists(voc_override):
  219. voc_config_path = voc_override
  220. logger.info('Start training....')
  221. if TtsTrainType.TRAIN_TYPE_SAMBERT in train_type:
  222. logger.info('Start SAMBERT training...')
  223. totaltime = datetime.datetime.now()
  224. hparams = train_type[TtsTrainType.TRAIN_TYPE_SAMBERT]
  225. target_voice.train_sambert(work_dir, am_dir, data_dir,
  226. am_config_path, ignore_pretrain,
  227. hparams)
  228. totaltime = datetime.datetime.now() - totaltime
  229. logger.info('SAMBERT training spent: {:.2f} hours\n'.format(
  230. totaltime.total_seconds() / 3600.0))
  231. else:
  232. logger.info('skip SAMBERT training...')
  233. if TtsTrainType.TRAIN_TYPE_VOC in train_type:
  234. logger.info('Start HIFIGAN training...')
  235. totaltime = datetime.datetime.now()
  236. hparams = train_type[TtsTrainType.TRAIN_TYPE_VOC]
  237. target_voice.train_hifigan(work_dir, voc_dir, data_dir,
  238. voc_config_path, ignore_pretrain,
  239. hparams)
  240. totaltime = datetime.datetime.now() - totaltime
  241. logger.info('HIFIGAN training spent: {:.2f} hours\n'.format(
  242. totaltime.total_seconds() / 3600.0))
  243. else:
  244. logger.info('skip HIFIGAN training...')
  245. def forward(self, text: str, voice_name: str = None):
  246. voice = self.default_voice_name
  247. if voice_name is not None:
  248. voice = voice_name
  249. result = self.frontend.gen_tacotron_symbols(text)
  250. texts = [s for s in result.splitlines() if s != '']
  251. audio_total = np.empty((0), dtype='int16')
  252. for line in texts:
  253. line = line.strip().split('\t')
  254. audio = self.synthesis_one_sentences(voice, line[1])
  255. audio = 32768.0 * audio
  256. audio_total = np.append(audio_total, audio.astype('int16'), axis=0)
  257. return ndarray_pcm_to_wav(self.sample_rate, audio_total)