| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- import shutil
- import tempfile
- import zipfile
- from typing import Callable, Dict, List, Optional, Tuple, Union
- import json
- from modelscope.metainfo import Preprocessors, Trainers
- from modelscope.models import Model
- from modelscope.models.audio.tts import SambertHifigan
- from modelscope.msdatasets import MsDataset
- from modelscope.preprocessors.builder import build_preprocessor
- from modelscope.trainers.base import BaseTrainer
- from modelscope.trainers.builder import TRAINERS
- from modelscope.utils.audio.audio_utils import TtsTrainType
- from modelscope.utils.audio.tts_exceptions import (
- TtsTrainingCfgNotExistsException, TtsTrainingDatasetInvalidException,
- TtsTrainingHparamsInvalidException, TtsTrainingInvalidModelException,
- TtsTrainingWorkDirNotExistsException)
- from modelscope.utils.config import Config
- from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
- DEFAULT_DATASET_REVISION,
- DEFAULT_MODEL_REVISION, ModelFile,
- Tasks, TrainerStages)
- from modelscope.utils.data_utils import to_device
- from modelscope.utils.logger import get_logger
- logger = get_logger()
@TRAINERS.register_module(module_name=Trainers.speech_kantts_trainer)
class KanttsTrainer(BaseTrainer):
    """Trainer that fine-tunes a ``SambertHifigan`` text-to-speech model.

    Construction prepares a working directory, copies the base model into
    it, reads the ``train`` section of the model configuration and resolves
    the raw training dataset path. :meth:`train` then runs audio data
    preprocessing when it is needed and delegates the actual training to
    ``SambertHifigan.train``.
    """

    # Names of the sub-directories created inside the working directory.
    DATA_DIR = 'data'
    AM_TMP_DIR = 'tmp_am'
    VOC_TMP_DIR = 'tmp_voc'
    ORIG_MODEL_DIR = 'orig_model'

    def __init__(self,
                 model: Union[Model, str],
                 work_dir: str = None,
                 speaker: str = 'F7',
                 lang_type: str = 'PinYin',
                 cfg_file: str = None,
                 train_dataset: Union[MsDataset, str] = None,
                 train_dataset_namespace: str = DEFAULT_DATASET_NAMESPACE,
                 train_dataset_revision: str = DEFAULT_DATASET_REVISION,
                 train_type: Optional[dict] = None,
                 preprocess_skip_script=False,
                 model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
                 **kwargs):
        """Set up directories, the base model and the training dataset.

        Args:
            model: A ``Model`` instance (its ``model_dir`` is used) or a
                string that is resolved with
                ``self.get_or_download_model_dir`` (not defined in this
                class).
            work_dir: Existing directory to work in. When empty, a new
                temporary directory is created.
            speaker: Voice name handed to the model and the preprocessor.
            lang_type: Accepted for interface compatibility; the effective
                ``self.lang_type`` is read from the loaded model.
            cfg_file: Configuration file to parse. Defaults to
                ``ModelFile.CONFIGURATION`` inside the copied model
                directory.
            train_dataset: An existing local path, a dataset name loaded
                with ``MsDataset.load``, or an ``MsDataset`` instance.
            train_dataset_namespace: Namespace used when loading the
                dataset by name.
            train_dataset_revision: Revision used when loading the dataset
                by name.
            train_type: Mapping from ``TtsTrainType`` values to per-type
                options. Keys other than sambert, voc and bert are
                dropped. ``None``, a non-dict value or an empty result
                selects sambert and voc with empty options.
            preprocess_skip_script: Forwarded to the audio data
                preprocessor.
            model_revision: Revision used when ``model`` is a string.
            **kwargs: Only ``device`` is read (default ``'gpu'``).

        Raises:
            TtsTrainingWorkDirNotExistsException: ``work_dir`` was given
                but does not exist.
            TtsTrainingInvalidModelException: ``model`` is empty or the
                configuration has no ``train`` section.
            TtsTrainingDatasetInvalidException: no existing raw dataset
                path could be resolved.
        """
        if not work_dir:
            # mkdtemp creates the directory and leaves it in place; a
            # TemporaryDirectory object would delete it again as soon as
            # the unreferenced object is finalized.
            self.work_dir = tempfile.mkdtemp()
        else:
            self.work_dir = work_dir

        if not os.path.exists(self.work_dir):
            raise TtsTrainingWorkDirNotExistsException(
                f'{self.work_dir} not exists')

        if train_type is None:
            # Build the default per call so the option dicts are never
            # shared between trainer instances.
            train_type = {
                TtsTrainType.TRAIN_TYPE_SAMBERT: {},
                TtsTrainType.TRAIN_TYPE_VOC: {}
            }

        supported_types = (TtsTrainType.TRAIN_TYPE_SAMBERT,
                           TtsTrainType.TRAIN_TYPE_VOC,
                           TtsTrainType.TRAIN_TYPE_BERT)
        self.train_type = {}
        if isinstance(train_type, dict):
            self.train_type = {
                key: options
                for key, options in train_type.items()
                if key in supported_types
            }

        if not self.train_type:
            logger.info('train type empty, default to sambert and voc')
            self.train_type[TtsTrainType.TRAIN_TYPE_SAMBERT] = {}
            self.train_type[TtsTrainType.TRAIN_TYPE_VOC] = {}

        logger.info(f'Set workdir to {self.work_dir}')

        self.data_dir = os.path.join(self.work_dir, self.DATA_DIR)
        self.am_tmp_dir = os.path.join(self.work_dir, self.AM_TMP_DIR)
        self.voc_tmp_dir = os.path.join(self.work_dir, self.VOC_TMP_DIR)
        self.orig_model_dir = os.path.join(self.work_dir, self.ORIG_MODEL_DIR)
        self.raw_dataset_path = ''
        self.skip_script = preprocess_skip_script
        self.audio_config_path = ''
        self.am_config_path = ''
        self.voc_config_path = ''
        # Always defined so prepare_data() is safe for every train type.
        self.audio_data_preprocessor = None

        # Remove leftovers from a previous run in the same work_dir.
        for stale_dir in (self.data_dir, self.am_tmp_dir, self.voc_tmp_dir,
                          self.orig_model_dir):
            shutil.rmtree(stale_dir, ignore_errors=True)

        # orig_model_dir is deliberately not created here: it is the
        # destination of shutil.copytree below, which creates it.
        os.makedirs(self.data_dir)
        os.makedirs(self.am_tmp_dir)
        os.makedirs(self.voc_tmp_dir)

        if train_dataset:
            if isinstance(train_dataset, str) and os.path.exists(
                    train_dataset):
                logger.info(f'load {train_dataset}')
                self.raw_dataset_path = train_dataset
            else:
                if isinstance(train_dataset, str):
                    logger.info(
                        f'load {train_dataset_namespace}/{train_dataset}')
                    train_dataset = MsDataset.load(
                        dataset_name=train_dataset,
                        namespace=train_dataset_namespace,
                        version=train_dataset_revision)
                    logger.info(
                        f'train dataset:{train_dataset.config_kwargs}')
                self.raw_dataset_path = self.load_dataset_raw_path(
                    train_dataset)

        if not model:
            raise TtsTrainingInvalidModelException('model param is none')
        if isinstance(model, str):
            model_dir = self.get_or_download_model_dir(model, model_revision)
        else:
            model_dir = model.model_dir
        # Work on a private copy of the model files inside work_dir.
        shutil.copytree(model_dir, self.orig_model_dir)
        self.model_dir = self.orig_model_dir

        if not cfg_file:
            cfg_file = os.path.join(self.model_dir, ModelFile.CONFIGURATION)
        # May also fill raw_dataset_path when no dataset was passed in.
        self.parse_cfg(cfg_file)

        if not os.path.exists(self.raw_dataset_path):
            raise TtsTrainingDatasetInvalidException(
                'dataset raw path not exists')

        self.finetune_from_pretrain = False
        self.speaker = speaker
        self.model = None
        self.device = kwargs.get('device', 'gpu')

        self.model = self.get_model(self.model_dir, self.speaker)
        self.lang_type = self.model.lang_type

        if self._needs_audio_data():
            self.audio_data_preprocessor = build_preprocessor(
                dict(type=Preprocessors.kantts_data_preprocessor),
                Tasks.text_to_speech)

    def _needs_audio_data(self) -> bool:
        """Return True when sambert or vocoder training was requested."""
        return (TtsTrainType.TRAIN_TYPE_SAMBERT in self.train_type
                or TtsTrainType.TRAIN_TYPE_VOC in self.train_type)

    def parse_cfg(self, cfg_file):
        """Read the ``train`` section of a JSON configuration file.

        Stores the audio/am/voc configuration paths that exist on disk
        (resolved relative to the directory of ``cfg_file``). When no raw
        dataset path has been set yet, it also tries to resolve one from
        the ``train_dataset`` entry.

        Args:
            cfg_file: Path of the JSON configuration file.

        Raises:
            TtsTrainingInvalidModelException: the file has no ``train``
                section.
        """
        cur_dir = os.path.dirname(cfg_file)
        with open(cfg_file, 'r', encoding='utf-8') as f:
            config = json.load(f)

        if 'train' not in config:
            raise TtsTrainingInvalidModelException(
                'model not support finetune')
        train_cfg = config['train']

        # Each entry is only recorded when the referenced file exists.
        for key, attr in (('audio_config', 'audio_config_path'),
                          ('am_config', 'am_config_path'),
                          ('voc_config', 'voc_config_path')):
            if key in train_cfg:
                candidate = os.path.join(cur_dir, train_cfg[key])
                if os.path.exists(candidate):
                    setattr(self, attr, candidate)

        # A dataset passed to the constructor takes precedence.
        if self.raw_dataset_path or 'train_dataset' not in train_cfg:
            return

        dataset = train_cfg['train_dataset']
        if isinstance(dataset, str):
            # Plain string: treated as a local path, used only if present.
            if os.path.exists(dataset):
                self.raw_dataset_path = dataset
        elif isinstance(dataset, dict):
            # Mapping: must be recognised before any path check, because
            # os.path.exists only accepts path-like values.
            if 'id' in dataset:
                namespace = dataset.get('namespace',
                                        DEFAULT_DATASET_NAMESPACE)
                revision = dataset.get('revision', DEFAULT_DATASET_REVISION)
                ms = MsDataset.load(
                    dataset_name=dataset['id'],
                    namespace=namespace,
                    version=revision)
                self.raw_dataset_path = self.load_dataset_raw_path(ms)
            elif 'path' in dataset:
                self.raw_dataset_path = dataset['path']

    def load_dataset_raw_path(self, dataset: MsDataset):
        """Return the ``train`` entry of the dataset's ``split_config``.

        Args:
            dataset: Dataset whose ``config_kwargs`` is inspected.

        Raises:
            TtsTrainingDatasetInvalidException: ``split_config`` or its
                ``train`` entry is missing.
        """
        config_kwargs = dataset.config_kwargs
        if 'split_config' not in config_kwargs:
            raise TtsTrainingDatasetInvalidException(
                'split_config not found in config_kwargs')
        split_config = config_kwargs['split_config']
        if 'train' not in split_config:
            raise TtsTrainingDatasetInvalidException(
                'no train split in split_config')
        return split_config['train']

    def prepare_data(self):
        """Run the audio data preprocessor on the raw dataset.

        Does nothing when no audio data preprocessor was built. The audio
        configuration from the parsed config is used when it exists;
        otherwise the model is asked for the speaker's configuration.
        """
        if not self.audio_data_preprocessor:
            return

        audio_config = self.audio_config_path
        if not audio_config or not os.path.exists(audio_config):
            audio_config = self.model.get_voice_audio_config_path(
                self.speaker)
        se_model = self.model.get_voice_se_model_path(self.speaker)
        self.audio_data_preprocessor(self.raw_dataset_path, self.data_dir,
                                     audio_config, self.speaker,
                                     self.lang_type, self.skip_script,
                                     se_model)

    def prepare_text(self):
        """Placeholder for text preparation; currently does nothing."""
        pass

    def get_model(self, model_dir, speaker):
        """Build a ``SambertHifigan`` in training mode from ``model_dir``.

        Args:
            model_dir: Directory holding ``ModelFile.CONFIGURATION`` and
                the model files.
            speaker: Not used by this method; kept for interface
                compatibility.

        Returns:
            The constructed ``SambertHifigan`` instance.
        """
        cfg = Config.from_file(
            os.path.join(model_dir, ModelFile.CONFIGURATION))
        model_cfg = cfg.get('model', {})
        return SambertHifigan(
            model_dir=model_dir, is_train=True, **model_cfg)

    def train(self, *args, **kwargs):
        """Prepare the data that is needed and train the model.

        Args:
            *args: Ignored.
            **kwargs: ``ignore_pretrain`` (default ``False``) is forwarded
                to the model's ``train`` method; other keys are ignored.

        Raises:
            TtsTrainingInvalidModelException: no model is loaded.
        """
        if not self.model:
            raise TtsTrainingInvalidModelException('model is none')

        ignore_pretrain = kwargs.get('ignore_pretrain', False)

        if self._needs_audio_data():
            self.prepare_data()
        if TtsTrainType.TRAIN_TYPE_BERT in self.train_type:
            self.prepare_text()

        dir_dict = {
            'work_dir': self.work_dir,
            'am_tmp_dir': self.am_tmp_dir,
            'voc_tmp_dir': self.voc_tmp_dir,
            'data_dir': self.data_dir
        }
        config_dict = {
            'am_config': self.am_config_path,
            'voc_config': self.voc_config_path
        }
        self.model.train(self.speaker, dir_dict, self.train_type, config_dict,
                         ignore_pretrain)

    def evaluate(self, checkpoint_path: str, *args,
                 **kwargs) -> Dict[str, float]:
        """Evaluation is not implemented; always returns an empty dict."""
        return {}
|