| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- import shutil
- import tempfile
- from typing import Dict, Optional, Union
- import json
- from funasr.bin import build_trainer
- from modelscope.metainfo import Trainers
- from modelscope.msdatasets import MsDataset
- from modelscope.trainers.base import BaseTrainer
- from modelscope.trainers.builder import TRAINERS
- from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
- DEFAULT_DATASET_REVISION,
- DEFAULT_MODEL_REVISION, ModelFile,
- Tasks, TrainerStages)
- from modelscope.utils.logger import get_logger
- logger = get_logger()
@TRAINERS.register_module(module_name=Trainers.speech_asr_trainer)
class ASRTrainer(BaseTrainer):
    """Trainer that prepares data and configuration, then delegates to FunASR.

    The constructor resolves the working directory and the model directory,
    reads ``configuration.json`` from the model directory, optionally dumps
    the dataset into ``wav.scp``/``text`` files, and finally builds the
    underlying trainer through ``funasr.bin.build_trainer.build_trainer``.
    """

    # Name of the sub-directory of ``work_dir`` that receives the dumped data.
    DATA_DIR = 'data'

    def __init__(self,
                 model: str,
                 work_dir: Optional[str] = None,
                 distributed: bool = False,
                 dataset_type: str = 'small',
                 data_dir: Optional[Union[MsDataset, str]] = None,
                 model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
                 batch_bins: Optional[int] = None,
                 max_epoch: Optional[int] = None,
                 lr: Optional[float] = None,
                 mate_params: Optional[dict] = None,
                 **kwargs):
        """ASR Trainer.

        Args:
            model (str): model name, or a path to an existing local model
                directory
            work_dir (str): output dir for saving results; when given it must
                already exist, when omitted a fresh temporary directory is
                created
            distributed (bool): whether to enable DDP training
            dataset_type (str): choose which dataset type to use
            data_dir: the training data; either a dataset object exposing
                'train' and 'validation' splits, or a mapping holding a
                'raw_data_dir' entry that points at already prepared data
            model_revision (str): set model version
            batch_bins (int): batch size
            max_epoch (int): the maximum epoch number for training
            lr (float): learning rate
            mate_params (dict): for saving other training args

        Raises:
            TypeError: if ``data_dir`` is not provided.
            Exception: if ``work_dir`` is given but does not exist, or if the
                dataset lacks a 'train' or 'validation' split.

        Examples:

        >>> import os
        >>> from modelscope.metainfo import Trainers
        >>> from modelscope.msdatasets import MsDataset
        >>> from modelscope.trainers import build_trainer
        >>> ds_dict = MsDataset.load('speech_asr_aishell1_trainsets')
        >>> kwargs = dict(
        >>>     model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
        >>>     data_dir=ds_dict,
        >>>     work_dir="./checkpoint")
        >>> trainer = build_trainer(
        >>>     Trainers.speech_asr_trainer, default_args=kwargs)
        >>> trainer.train()

        """
        # Fail fast, before any directory is created or wiped and before the
        # model is downloaded: the membership test further down cannot work
        # on None and would otherwise raise an opaque TypeError.
        if data_dir is None:
            raise TypeError(
                "data_dir is required: pass a dataset with 'train' and "
                "'validation' splits, or a mapping with a 'raw_data_dir' "
                'entry')

        if not work_dir:
            # mkdtemp() creates the directory and hands ownership to us.
            # TemporaryDirectory().name would drop the owning object, whose
            # finalizer removes the directory again.
            self.work_dir = tempfile.mkdtemp()
        else:
            self.work_dir = work_dir
            if not os.path.exists(self.work_dir):
                raise Exception(f'{self.work_dir} not exists')

        logger.info(f'Set workdir to {self.work_dir}')

        self.data_dir = os.path.join(self.work_dir, self.DATA_DIR)
        self.raw_dataset_path = ''
        self.distributed = distributed
        self.dataset_type = dataset_type
        # Only populated when the dataset is dumped by this trainer.
        self.train_data_dir = None
        self.dev_data_dir = None

        # Always start from an empty data directory.
        shutil.rmtree(self.data_dir, ignore_errors=True)
        os.makedirs(self.data_dir, exist_ok=True)

        # A value that exists on disk is used as the model directory as-is;
        # anything else is resolved through get_or_download_model_dir.
        if os.path.exists(model):
            model_dir = model
        else:
            model_dir = self.get_or_download_model_dir(model, model_revision)
        self.model_dir = model_dir
        self.model_cfg = os.path.join(self.model_dir, 'configuration.json')
        self.cfg_dict = self.parse_cfg(self.model_cfg)

        if 'raw_data_dir' not in data_dir:
            self.train_data_dir, self.dev_data_dir = self.load_dataset_raw_path(
                data_dir, self.data_dir)
        else:
            # The caller already has prepared data; use it in place.
            self.data_dir = data_dir['raw_data_dir']

        self.trainer = build_trainer.build_trainer(
            modelscope_dict=self.cfg_dict,
            data_dir=self.data_dir,
            output_dir=self.work_dir,
            distributed=self.distributed,
            dataset_type=self.dataset_type,
            batch_bins=batch_bins,
            max_epoch=max_epoch,
            lr=lr,
            mate_params=mate_params)

    def parse_cfg(self, cfg_file):
        """Read the model's JSON configuration into a flat dict of paths.

        Args:
            cfg_file (str): path to the model's ``configuration.json``.

        Returns:
            dict: the keys 'mode', 'model_dir', 'am_model_file',
            'am_model_config', 'finetune_config', 'cmvn_file', 'seg_dict',
            'bpemodel' and 'init_model'. File entries are joined onto the
            directory that contains ``cfg_file``.
        """
        cur_dir = os.path.dirname(cfg_file)
        with open(cfg_file, 'r', encoding='utf-8') as f:
            config = json.load(f)

        model_section = config['model']
        model_config = model_section['model_config']

        cfg_dict = dict()
        cfg_dict['mode'] = model_config['mode']
        cfg_dict['model_dir'] = cur_dir
        cfg_dict['am_model_file'] = os.path.join(
            cur_dir, model_section['am_model_name'])
        cfg_dict['am_model_config'] = os.path.join(
            cur_dir, model_config['am_model_config'])
        cfg_dict['finetune_config'] = os.path.join(cur_dir, 'finetune.yaml')
        cfg_dict['cmvn_file'] = os.path.join(cur_dir,
                                             model_config['mvn_file'])
        cfg_dict['seg_dict'] = os.path.join(cur_dir, 'seg_dict')

        # 'bpemodel' is optional and has no fallback.
        if 'bpemodel' in model_config:
            cfg_dict['bpemodel'] = os.path.join(cur_dir,
                                                model_config['bpemodel'])
        else:
            cfg_dict['bpemodel'] = None

        # Without an explicit 'init_model', start from the acoustic model.
        if 'init_model' in model_config:
            cfg_dict['init_model'] = os.path.join(cur_dir,
                                                  model_config['init_model'])
        else:
            cfg_dict['init_model'] = cfg_dict['am_model_file']
        return cfg_dict

    def load_dataset_raw_path(self, dataset, output_data_dir):
        """Dump the 'train' and 'validation' splits under ``output_data_dir``.

        Args:
            dataset: container supporting ``in`` and ``[]`` by split name.
            output_data_dir (str): directory that receives one sub-directory
                per split.

        Returns:
            tuple: ``(train_data_dir, dev_data_dir)``.

        Raises:
            Exception: if either split is missing from ``dataset``.
        """
        # Check both splits before writing anything, so a missing validation
        # split is reported without first dumping the whole train split.
        if 'train' not in dataset:
            raise Exception(
                'dataset {0} does not contain a train split'.format(dataset))
        if 'validation' not in dataset:
            raise Exception(
                'dataset {0} does not contain a dev split'.format(dataset))

        train_data_dir = self.prepare_data(
            dataset, output_data_dir, split='train')
        dev_data_dir = self.prepare_data(
            dataset, output_data_dir, split='validation')
        return train_data_dir, dev_data_dir

    def prepare_data(self, dataset, out_base_dir, split='train'):
        """Write ``wav.scp`` and ``text`` files for one split.

        Each sample contributes one tab-separated line to each file, keyed
        by the basename of its audio file: ``<key>\\t<audio path>`` in
        ``wav.scp`` and ``<key>\\t<transcript>`` in ``text``.

        Args:
            dataset: container indexable by split name; each split must
                support ``len()`` and integer indexing, and each sample must
                provide the 'Audio:FILE' and 'Text:LABEL' fields.
            out_base_dir (str): parent directory of the split directory.
            split (str): name of the split to dump.

        Returns:
            str: the directory the two files were written to.
        """
        out_dir = os.path.join(out_base_dir, split)
        # Remove any previous dump of this split.
        shutil.rmtree(out_dir, ignore_errors=True)
        os.makedirs(out_dir, exist_ok=True)

        samples = dataset[split]
        wav_scp_path = os.path.join(out_dir, 'wav.scp')
        text_path = os.path.join(out_dir, 'text')
        # Transcripts may be non-ASCII, so the encoding is pinned instead of
        # relying on the platform default; the context manager closes both
        # files even if a sample raises.
        with open(wav_scp_path, 'w', encoding='utf-8') as fp_wav_scp, \
                open(text_path, 'w', encoding='utf-8') as fp_text:
            # Index-based access is kept on purpose: only len() and integer
            # indexing are known to be supported by the split object.
            for i in range(len(samples)):
                content = samples[i]
                wav_file = content['Audio:FILE']
                text = content['Text:LABEL']
                utt_key = os.path.basename(wav_file)
                fp_wav_scp.write('\t'.join([utt_key, wav_file]) + '\n')
                fp_text.write('\t'.join([utt_key, text]) + '\n')
        return out_dir

    def train(self, *args, **kwargs):
        """Run the underlying trainer; extra arguments are ignored."""
        self.trainer.run()

    def evaluate(self, checkpoint_path: str, *args,
                 **kwargs) -> Dict[str, float]:
        """Evaluation is not supported by this trainer.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
|