asr_trainer.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. from typing import Dict, Optional, Union
  6. import json
  7. from funasr.bin import build_trainer
  8. from modelscope.metainfo import Trainers
  9. from modelscope.msdatasets import MsDataset
  10. from modelscope.trainers.base import BaseTrainer
  11. from modelscope.trainers.builder import TRAINERS
  12. from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
  13. DEFAULT_DATASET_REVISION,
  14. DEFAULT_MODEL_REVISION, ModelFile,
  15. Tasks, TrainerStages)
  16. from modelscope.utils.logger import get_logger
  17. logger = get_logger()
  18. @TRAINERS.register_module(module_name=Trainers.speech_asr_trainer)
  19. class ASRTrainer(BaseTrainer):
  20. DATA_DIR = 'data'
  21. def __init__(self,
  22. model: str,
  23. work_dir: str = None,
  24. distributed: bool = False,
  25. dataset_type: str = 'small',
  26. data_dir: Optional[Union[MsDataset, str]] = None,
  27. model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
  28. batch_bins: Optional[int] = None,
  29. max_epoch: Optional[int] = None,
  30. lr: Optional[float] = None,
  31. mate_params: Optional[dict] = None,
  32. **kwargs):
  33. """ASR Trainer.
  34. Args:
  35. model (str) : model name
  36. work_dir (str): output dir for saving results
  37. distributed (bool): whether to enable DDP training
  38. dataset_type (str): choose which dataset type to use
  39. data_dir (str): the path of data
  40. model_revision (str): set model version
  41. batch_bins (str): batch size
  42. max_epoch (int): the maximum epoch number for training
  43. lr (float): learning rate
  44. mate_params (dict): for saving other training args
  45. Examples:
  46. >>> import os
  47. >>> from modelscope.metainfo import Trainers
  48. >>> from modelscope.msdatasets import MsDataset
  49. >>> from modelscope.trainers import build_trainer
  50. >>> ds_dict = MsDataset.load('speech_asr_aishell1_trainsets')
  51. >>> kwargs = dict(
  52. >>> model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
  53. >>> data_dir=ds_dict,
  54. >>> work_dir="./checkpoint")
  55. >>> trainer = build_trainer(
  56. >>> Trainers.speech_asr_trainer, default_args=kwargs)
  57. >>> trainer.train()
  58. """
  59. if not work_dir:
  60. self.work_dir = tempfile.TemporaryDirectory().name
  61. if not os.path.exists(self.work_dir):
  62. os.makedirs(self.work_dir)
  63. else:
  64. self.work_dir = work_dir
  65. if not os.path.exists(self.work_dir):
  66. raise Exception(f'{self.work_dir} not exists')
  67. logger.info(f'Set workdir to {self.work_dir}')
  68. self.data_dir = os.path.join(self.work_dir, self.DATA_DIR)
  69. self.raw_dataset_path = ''
  70. self.distributed = distributed
  71. self.dataset_type = dataset_type
  72. shutil.rmtree(self.data_dir, ignore_errors=True)
  73. os.makedirs(self.data_dir, exist_ok=True)
  74. if os.path.exists(model):
  75. model_dir = model
  76. else:
  77. model_dir = self.get_or_download_model_dir(model, model_revision)
  78. self.model_dir = model_dir
  79. self.model_cfg = os.path.join(self.model_dir, 'configuration.json')
  80. self.cfg_dict = self.parse_cfg(self.model_cfg)
  81. if 'raw_data_dir' not in data_dir:
  82. self.train_data_dir, self.dev_data_dir = self.load_dataset_raw_path(
  83. data_dir, self.data_dir)
  84. else:
  85. self.data_dir = data_dir['raw_data_dir']
  86. self.trainer = build_trainer.build_trainer(
  87. modelscope_dict=self.cfg_dict,
  88. data_dir=self.data_dir,
  89. output_dir=self.work_dir,
  90. distributed=self.distributed,
  91. dataset_type=self.dataset_type,
  92. batch_bins=batch_bins,
  93. max_epoch=max_epoch,
  94. lr=lr,
  95. mate_params=mate_params)
  96. def parse_cfg(self, cfg_file):
  97. cur_dir = os.path.dirname(cfg_file)
  98. cfg_dict = dict()
  99. with open(cfg_file, 'r', encoding='utf-8') as f:
  100. config = json.load(f)
  101. cfg_dict['mode'] = config['model']['model_config']['mode']
  102. cfg_dict['model_dir'] = cur_dir
  103. cfg_dict['am_model_file'] = os.path.join(
  104. cur_dir, config['model']['am_model_name'])
  105. cfg_dict['am_model_config'] = os.path.join(
  106. cur_dir, config['model']['model_config']['am_model_config'])
  107. cfg_dict['finetune_config'] = os.path.join(cur_dir,
  108. 'finetune.yaml')
  109. cfg_dict['cmvn_file'] = os.path.join(
  110. cur_dir, config['model']['model_config']['mvn_file'])
  111. cfg_dict['seg_dict'] = os.path.join(cur_dir, 'seg_dict')
  112. if 'bpemodel' in config['model']['model_config']:
  113. cfg_dict['bpemodel'] = os.path.join(
  114. cur_dir, config['model']['model_config']['bpemodel'])
  115. else:
  116. cfg_dict['bpemodel'] = None
  117. if 'init_model' in config['model']['model_config']:
  118. cfg_dict['init_model'] = os.path.join(
  119. cur_dir, config['model']['model_config']['init_model'])
  120. else:
  121. cfg_dict['init_model'] = cfg_dict['am_model_file']
  122. return cfg_dict
  123. def load_dataset_raw_path(self, dataset, output_data_dir):
  124. if 'train' not in dataset:
  125. raise Exception(
  126. 'dataset {0} does not contain a train split'.format(dataset))
  127. train_data_dir = self.prepare_data(
  128. dataset, output_data_dir, split='train')
  129. if 'validation' not in dataset:
  130. raise Exception(
  131. 'dataset {0} does not contain a dev split'.format(dataset))
  132. dev_data_dir = self.prepare_data(
  133. dataset, output_data_dir, split='validation')
  134. return train_data_dir, dev_data_dir
  135. def prepare_data(self, dataset, out_base_dir, split='train'):
  136. out_dir = os.path.join(out_base_dir, split)
  137. shutil.rmtree(out_dir, ignore_errors=True)
  138. os.makedirs(out_dir, exist_ok=True)
  139. data_cnt = len(dataset[split])
  140. fp_wav_scp = open(os.path.join(out_dir, 'wav.scp'), 'w')
  141. fp_text = open(os.path.join(out_dir, 'text'), 'w')
  142. for i in range(data_cnt):
  143. content = dataset[split][i]
  144. wav_file = content['Audio:FILE']
  145. text = content['Text:LABEL']
  146. fp_wav_scp.write('\t'.join([os.path.basename(wav_file), wav_file])
  147. + '\n')
  148. fp_text.write('\t'.join([os.path.basename(wav_file), text]) + '\n')
  149. fp_text.close()
  150. fp_wav_scp.close()
  151. return out_dir
  152. def train(self, *args, **kwargs):
  153. self.trainer.run()
  154. def evaluate(self, checkpoint_path: str, *args,
  155. **kwargs) -> Dict[str, float]:
  156. raise NotImplementedError