tts_trainer.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import zipfile
  6. from typing import Callable, Dict, List, Optional, Tuple, Union
  7. import json
  8. from modelscope.metainfo import Preprocessors, Trainers
  9. from modelscope.models import Model
  10. from modelscope.models.audio.tts import SambertHifigan
  11. from modelscope.msdatasets import MsDataset
  12. from modelscope.preprocessors.builder import build_preprocessor
  13. from modelscope.trainers.base import BaseTrainer
  14. from modelscope.trainers.builder import TRAINERS
  15. from modelscope.utils.audio.audio_utils import TtsTrainType
  16. from modelscope.utils.audio.tts_exceptions import (
  17. TtsTrainingCfgNotExistsException, TtsTrainingDatasetInvalidException,
  18. TtsTrainingHparamsInvalidException, TtsTrainingInvalidModelException,
  19. TtsTrainingWorkDirNotExistsException)
  20. from modelscope.utils.config import Config
  21. from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
  22. DEFAULT_DATASET_REVISION,
  23. DEFAULT_MODEL_REVISION, ModelFile,
  24. Tasks, TrainerStages)
  25. from modelscope.utils.data_utils import to_device
  26. from modelscope.utils.logger import get_logger
# Module-level logger used by the trainer class defined in this file.
logger = get_logger()
  28. @TRAINERS.register_module(module_name=Trainers.speech_kantts_trainer)
  29. class KanttsTrainer(BaseTrainer):
  30. DATA_DIR = 'data'
  31. AM_TMP_DIR = 'tmp_am'
  32. VOC_TMP_DIR = 'tmp_voc'
  33. ORIG_MODEL_DIR = 'orig_model'
  34. def __init__(self,
  35. model: Union[Model, str],
  36. work_dir: str = None,
  37. speaker: str = 'F7',
  38. lang_type: str = 'PinYin',
  39. cfg_file: str = None,
  40. train_dataset: Union[MsDataset, str] = None,
  41. train_dataset_namespace: str = DEFAULT_DATASET_NAMESPACE,
  42. train_dataset_revision: str = DEFAULT_DATASET_REVISION,
  43. train_type: dict = {
  44. TtsTrainType.TRAIN_TYPE_SAMBERT: {},
  45. TtsTrainType.TRAIN_TYPE_VOC: {}
  46. },
  47. preprocess_skip_script=False,
  48. model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
  49. **kwargs):
  50. if not work_dir:
  51. self.work_dir = tempfile.TemporaryDirectory().name
  52. if not os.path.exists(self.work_dir):
  53. os.makedirs(self.work_dir)
  54. else:
  55. self.work_dir = work_dir
  56. if not os.path.exists(self.work_dir):
  57. raise TtsTrainingWorkDirNotExistsException(
  58. f'{self.work_dir} not exists')
  59. self.train_type = dict()
  60. if isinstance(train_type, dict):
  61. for k, v in train_type.items():
  62. if (k == TtsTrainType.TRAIN_TYPE_SAMBERT
  63. or k == TtsTrainType.TRAIN_TYPE_VOC
  64. or k == TtsTrainType.TRAIN_TYPE_BERT):
  65. self.train_type[k] = v
  66. if len(self.train_type) == 0:
  67. logger.info('train type empty, default to sambert and voc')
  68. self.train_type[TtsTrainType.TRAIN_TYPE_SAMBERT] = {}
  69. self.train_type[TtsTrainType.TRAIN_TYPE_VOC] = {}
  70. logger.info(f'Set workdir to {self.work_dir}')
  71. self.data_dir = os.path.join(self.work_dir, self.DATA_DIR)
  72. self.am_tmp_dir = os.path.join(self.work_dir, self.AM_TMP_DIR)
  73. self.voc_tmp_dir = os.path.join(self.work_dir, self.VOC_TMP_DIR)
  74. self.orig_model_dir = os.path.join(self.work_dir, self.ORIG_MODEL_DIR)
  75. self.raw_dataset_path = ''
  76. self.skip_script = preprocess_skip_script
  77. self.audio_config_path = ''
  78. self.am_config_path = ''
  79. self.voc_config_path = ''
  80. shutil.rmtree(self.data_dir, ignore_errors=True)
  81. shutil.rmtree(self.am_tmp_dir, ignore_errors=True)
  82. shutil.rmtree(self.voc_tmp_dir, ignore_errors=True)
  83. shutil.rmtree(self.orig_model_dir, ignore_errors=True)
  84. os.makedirs(self.data_dir)
  85. os.makedirs(self.am_tmp_dir)
  86. os.makedirs(self.voc_tmp_dir)
  87. if train_dataset:
  88. if isinstance(train_dataset, str):
  89. if os.path.exists(train_dataset):
  90. logger.info(f'load {train_dataset}')
  91. self.raw_dataset_path = train_dataset
  92. else:
  93. logger.info(
  94. f'load {train_dataset_namespace}/{train_dataset}')
  95. train_dataset = MsDataset.load(
  96. dataset_name=train_dataset,
  97. namespace=train_dataset_namespace,
  98. version=train_dataset_revision)
  99. logger.info(f'train dataset:{train_dataset.config_kwargs}')
  100. self.raw_dataset_path = self.load_dataset_raw_path(
  101. train_dataset)
  102. else:
  103. self.raw_dataset_path = self.load_dataset_raw_path(
  104. train_dataset)
  105. if not model:
  106. raise TtsTrainingInvalidModelException('model param is none')
  107. if isinstance(model, str):
  108. model_dir = self.get_or_download_model_dir(model, model_revision)
  109. else:
  110. model_dir = model.model_dir
  111. shutil.copytree(model_dir, self.orig_model_dir)
  112. self.model_dir = self.orig_model_dir
  113. if not cfg_file:
  114. cfg_file = os.path.join(self.model_dir, ModelFile.CONFIGURATION)
  115. self.parse_cfg(cfg_file)
  116. if not os.path.exists(self.raw_dataset_path):
  117. raise TtsTrainingDatasetInvalidException(
  118. 'dataset raw path not exists')
  119. self.finetune_from_pretrain = False
  120. self.speaker = speaker
  121. self.model = None
  122. self.device = kwargs.get('device', 'gpu')
  123. self.model = self.get_model(self.model_dir, self.speaker)
  124. self.lang_type = self.model.lang_type
  125. if TtsTrainType.TRAIN_TYPE_SAMBERT in self.train_type or TtsTrainType.TRAIN_TYPE_VOC in self.train_type:
  126. self.audio_data_preprocessor = build_preprocessor(
  127. dict(type=Preprocessors.kantts_data_preprocessor),
  128. Tasks.text_to_speech)
  129. def parse_cfg(self, cfg_file):
  130. cur_dir = os.path.dirname(cfg_file)
  131. with open(cfg_file, 'r', encoding='utf-8') as f:
  132. config = json.load(f)
  133. if 'train' not in config:
  134. raise TtsTrainingInvalidModelException(
  135. 'model not support finetune')
  136. if 'audio_config' in config['train']:
  137. audio_config = os.path.join(cur_dir,
  138. config['train']['audio_config'])
  139. if os.path.exists(audio_config):
  140. self.audio_config_path = audio_config
  141. if 'am_config' in config['train']:
  142. am_config = os.path.join(cur_dir, config['train']['am_config'])
  143. if os.path.exists(am_config):
  144. self.am_config_path = am_config
  145. if 'voc_config' in config['train']:
  146. voc_config = os.path.join(cur_dir,
  147. config['train']['voc_config'])
  148. if os.path.exists(voc_config):
  149. self.voc_config_path = voc_config
  150. if not self.raw_dataset_path:
  151. if 'train_dataset' in config['train']:
  152. dataset = config['train']['train_dataset']
  153. if os.path.exists(dataset):
  154. self.raw_dataset_path = dataset
  155. else:
  156. if 'id' in dataset:
  157. namespace = dataset.get('namespace',
  158. DEFAULT_DATASET_NAMESPACE)
  159. revision = dataset.get('revision',
  160. DEFAULT_DATASET_REVISION)
  161. ms = MsDataset.load(
  162. dataset_name=dataset['id'],
  163. namespace=namespace,
  164. version=revision)
  165. self.raw_dataset_path = self.load_dataset_raw_path(
  166. ms)
  167. elif 'path' in dataset:
  168. self.raw_dataset_path = dataset['path']
  169. def load_dataset_raw_path(self, dataset: MsDataset):
  170. if 'split_config' not in dataset.config_kwargs:
  171. raise TtsTrainingDatasetInvalidException(
  172. 'split_config not found in config_kwargs')
  173. if 'train' not in dataset.config_kwargs['split_config']:
  174. raise TtsTrainingDatasetInvalidException(
  175. 'no train split in split_config')
  176. return dataset.config_kwargs['split_config']['train']
  177. def prepare_data(self):
  178. if self.audio_data_preprocessor:
  179. audio_config = self.audio_config_path
  180. if not audio_config or not os.path.exists(audio_config):
  181. audio_config = self.model.get_voice_audio_config_path(
  182. self.speaker)
  183. se_model = self.model.get_voice_se_model_path(self.speaker)
  184. self.audio_data_preprocessor(self.raw_dataset_path, self.data_dir,
  185. audio_config, self.speaker,
  186. self.lang_type, self.skip_script,
  187. se_model)
  188. def prepare_text(self):
  189. pass
  190. def get_model(self, model_dir, speaker):
  191. cfg = Config.from_file(
  192. os.path.join(self.model_dir, ModelFile.CONFIGURATION))
  193. model_cfg = cfg.get('model', {})
  194. model = SambertHifigan(
  195. model_dir=self.model_dir, is_train=True, **model_cfg)
  196. return model
  197. def train(self, *args, **kwargs):
  198. if not self.model:
  199. raise TtsTrainingInvalidModelException('model is none')
  200. ignore_pretrain = False
  201. if 'ignore_pretrain' in kwargs:
  202. ignore_pretrain = kwargs['ignore_pretrain']
  203. if TtsTrainType.TRAIN_TYPE_SAMBERT in self.train_type or TtsTrainType.TRAIN_TYPE_VOC in self.train_type:
  204. self.prepare_data()
  205. if TtsTrainType.TRAIN_TYPE_BERT in self.train_type:
  206. self.prepare_text()
  207. dir_dict = {
  208. 'work_dir': self.work_dir,
  209. 'am_tmp_dir': self.am_tmp_dir,
  210. 'voc_tmp_dir': self.voc_tmp_dir,
  211. 'data_dir': self.data_dir
  212. }
  213. config_dict = {
  214. 'am_config': self.am_config_path,
  215. 'voc_config': self.voc_config_path
  216. }
  217. self.model.train(self.speaker, dir_dict, self.train_type, config_dict,
  218. ignore_pretrain)
  219. def evaluate(self, checkpoint_path: str, *args,
  220. **kwargs) -> Dict[str, float]:
  221. return {}