kws_farfield_trainer.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. import datetime
  2. import glob
  3. import math
  4. import os
  5. import pickle
  6. from typing import Callable, Dict, Optional
  7. import numpy as np
  8. import torch
  9. from torch import nn as nn
  10. from torch import optim as optim
  11. from modelscope.metainfo import Trainers
  12. from modelscope.models import Model, TorchModel
  13. from modelscope.msdatasets.dataset_cls.custom_datasets.audio import (
  14. KWSDataLoader, KWSDataset)
  15. from modelscope.trainers.base import BaseTrainer
  16. from modelscope.trainers.builder import TRAINERS
  17. from modelscope.utils.audio.audio_utils import update_conf
  18. from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
  19. from modelscope.utils.data_utils import to_device
  20. from modelscope.utils.device import create_device
  21. from modelscope.utils.logger import get_logger
  22. from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
  23. init_dist, is_master)
  logger = get_logger()

  # Names of the data-config templates read from the model directory. Each
  # difficulty level has a base-training config and a fine-tuning config.
  BASETRAIN_CONF_EASY, BASETRAIN_CONF_NORMAL, BASETRAIN_CONF_HARD = (
      'basetrain_easy', 'basetrain_normal', 'basetrain_hard')
  FINETUNE_CONF_EASY, FINETUNE_CONF_NORMAL, FINETUNE_CONF_HARD = (
      'finetune_easy', 'finetune_normal', 'finetune_hard')

  # File-name prefix of the per-epoch checkpoints written to the work dir.
  CKPT_PREFIX = 'checkpoint'

  # Share of max_epochs given to the easy / normal / hard training stages.
  EASY_RATIO, NORMAL_RATIO, HARD_RATIO = 0.1, 0.6, 0.3

  # Default for the trainer's ``single_rate`` option (forwarded to KWSDataset).
  BASETRAIN_RATIO = 0.5
  @TRAINERS.register_module(module_name=Trainers.speech_dfsmn_kws_char_farfield)
  class KWSFarfieldTrainer(BaseTrainer):
      """Trainer for the far-field DFSMN keyword-spotting model.

      Training runs three consecutive stages (easy, normal, hard). Each stage
      streams batches from a ``KWSDataset`` built from a base-training and a
      fine-tuning data config, and gets a share of ``max_epochs`` given by
      ``EASY_RATIO`` / ``NORMAL_RATIO`` / ``HARD_RATIO``. After every epoch the
      model is validated and a checkpoint (the model's ``state_dict``) is
      written to ``work_dir``.
      """
      DEFAULT_WORK_DIR = './work_dir'
      # (base-training, fine-tuning) template pairs, one pair per stage in
      # easy / normal / hard order; stage ``s`` uses entries 2*s and 2*s+1.
      conf_keys = (BASETRAIN_CONF_EASY, FINETUNE_CONF_EASY,
                   BASETRAIN_CONF_NORMAL, FINETUNE_CONF_NORMAL,
                   BASETRAIN_CONF_HARD, FINETUNE_CONF_HARD)

      def __init__(self,
                   model: str,
                   work_dir: str,
                   cfg_file: Optional[str] = None,
                   arg_parse_fn: Optional[Callable] = None,
                   model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
                   custom_conf: Optional[dict] = None,
                   **kwargs):
          """Build the model, optimizer, loss and the per-stage data configs.

          Args:
              model: model id or local model directory (resolved with
                  ``get_or_download_model_dir`` when it is a ``str``).
              work_dir: directory for generated configs, logs, the
                  validation dump and checkpoints; created if missing.
                  ``None`` falls back to ``DEFAULT_WORK_DIR``.
              cfg_file: configuration file; defaults to the model's
                  configuration file when ``model`` is a ``str``.
              arg_parse_fn: forwarded to ``BaseTrainer``.
              model_revision: revision used when resolving ``model``.
              custom_conf: mapping with an entry for every name in
                  ``conf_keys``; each entry is passed to ``update_conf``
                  together with the matching template. Required in practice
                  despite the ``None`` default.
              **kwargs: optional overrides: ``num_syn``, ``launcher``,
                  ``device``, ``max_epochs``, ``train_iters_per_epoch``,
                  ``val_iters_per_epoch``, ``workers``, ``single_rate``,
                  ``next_epoch`` (1-based epoch to start at; > 1 resumes from
                  the checkpoint of the previous epoch) and ``model_bin``
                  (weights file relative to the model directory).

          Raises:
              TypeError: ``custom_conf`` is ``None``.
              KeyError: ``custom_conf`` lacks an entry for a config key.
              FileNotFoundError: resuming but no checkpoint matches.
              AssertionError: resuming and several checkpoints match.
          """
          # Validate before any download / model build so misuse fails fast.
          if custom_conf is None:
              raise TypeError('custom_conf is required: it must provide an '
                              f'entry for each of {self.conf_keys}')
          missing_keys = [key for key in self.conf_keys if key not in custom_conf]
          if missing_keys:
              raise KeyError(f'custom_conf has no entry for: {missing_keys}')

          if isinstance(model, str):
              self.model_dir = self.get_or_download_model_dir(
                  model, model_revision)
              if cfg_file is None:
                  cfg_file = os.path.join(self.model_dir,
                                          ModelFile.CONFIGURATION)
          else:
              assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
              self.model_dir = os.path.dirname(cfg_file)
          super().__init__(cfg_file, arg_parse_fn)

          # Number of model output classes. To support more wake words,
          # update the configuration outside the trainer (or pass num_syn).
          num_syn = kwargs.get('num_syn', None)
          if num_syn:
              self.cfg.model.num_syn = num_syn
          self._num_classes = self.cfg.model.num_syn
          self.model = self.build_model()

          self.work_dir = self.DEFAULT_WORK_DIR if work_dir is None else work_dir
          # Configs, logs and checkpoints are all written below work_dir.
          if self.work_dir:
              os.makedirs(self.work_dir, exist_ok=True)

          if kwargs.get('launcher', None) is not None:
              init_dist(kwargs['launcher'])
          _, world_size = get_dist_info()
          self._dist = world_size > 1
          device_name = kwargs.get('device', 'gpu')
          if self._dist:
              # Each process uses the GPU that matches its local rank.
              device_name = f'cuda:{get_local_rank()}'
          self.device = create_device(device_name)
          if self.device.type == 'cuda':
              self.model.to(self.device)

          if 'max_epochs' in kwargs:
              self._max_epochs = kwargs['max_epochs']
          else:
              assert hasattr(
                  self.cfg.train, 'max_epochs'
              ), 'max_epochs is missing from the configuration file'
              self._max_epochs = self.cfg.train.max_epochs

          self._train_iters = kwargs.get('train_iters_per_epoch', None)
          if self._train_iters is None:
              self._train_iters = self.cfg.train.train_iters_per_epoch
          self._val_iters = kwargs.get('val_iters_per_epoch', None)
          if self._val_iters is None:
              self._val_iters = self.cfg.evaluation.val_iters_per_epoch

          dataloader_config = self.cfg.train.dataloader
          self._threads = kwargs.get('workers', None)
          if self._threads is None:
              self._threads = dataloader_config.workers_per_gpu
          self._single_rate = kwargs.get('single_rate', BASETRAIN_RATIO)
          self._batch_size = dataloader_config.batch_size_per_gpu

          # _current_epoch counts finished epochs (epochs are 1-based).
          next_epoch = kwargs.get('next_epoch', 1)
          self._current_epoch = next_epoch - 1
          if 'model_bin' in kwargs:
              self._load_weights(
                  os.path.join(self.model_dir, kwargs['model_bin']))
          elif self._current_epoch > 0:
              # Resume from the checkpoint saved after the last finished epoch.
              ckpt_file_pattern = os.path.join(
                  self.work_dir, f'{CKPT_PREFIX}_{self._current_epoch:04d}*.pth')
              ckpt_files = glob.glob(ckpt_file_pattern)
              if len(ckpt_files) == 0:
                  raise FileNotFoundError(
                      f'Failed to load checkpoint file like '
                      f'{ckpt_file_pattern}. File not found!')
              if len(ckpt_files) > 1:
                  raise AssertionError(f'Expecting one but multiple checkpoint'
                                       f' files are found: {ckpt_files}')
              logger.info('Loading model from checkpoint: %s', ckpt_files[0])
              self._load_weights(ckpt_files[0])

          # Built after any weight loading so it tracks the final parameters.
          lr = self.cfg.train.optimizer.lr
          self.optimizer = optim.Adam(self.model.parameters(), lr)
          self.loss_fn = nn.CrossEntropyLoss()
          self.data_val = None
          self.json_log_path = os.path.join(self.work_dir,
                                            '{}.log.json'.format(self.timestamp))

          self.conf_files = []
          for conf_key in self.conf_keys:
              template_file = os.path.join(self.model_dir, conf_key)
              conf_file = os.path.join(self.work_dir, f'{conf_key}.conf')
              update_conf(template_file, conf_file, custom_conf[conf_key])
              self.conf_files.append(conf_file)

          easy = math.floor(self._max_epochs * EASY_RATIO)
          normal = math.floor(self._max_epochs * NORMAL_RATIO)
          hard = math.floor(self._max_epochs * HARD_RATIO)
          # Flooring can lose epochs (e.g. max_epochs=5 gives 0+3+1); hand the
          # leftover to the last stage so exactly max_epochs are trained.
          hard += max(0, self._max_epochs - easy - normal - hard)
          self.stages = (easy, normal, hard)

      def _load_weights(self, path):
          """Load weights from ``path`` into ``self.model``.

          Accepts the ``state_dict`` files written by this trainer. If the
          weights-only unpickler yields a whole ``nn.Module`` instead, that
          module replaces ``self.model``. Tensors are mapped to ``self.device``.
          """
          loaded = torch.load(path, map_location=self.device, weights_only=True)
          if isinstance(loaded, nn.Module):
              self.model = loaded
          else:
              self.model.load_state_dict(loaded)

      def build_model(self) -> nn.Module:
          """Instantiate a pytorch model and return it.

          By default the model is created from the configuration file. You can
          override this method in a subclass.

          Raises:
              TypeError: the loaded model is not a torch module.
          """
          model = Model.from_pretrained(
              self.model_dir, cfg_dict=self.cfg, training=True)
          if isinstance(model, TorchModel) and hasattr(model, 'model'):
              return model.model
          if isinstance(model, nn.Module):
              return model
          raise TypeError(f'Expected a torch model from {self.model_dir}, '
                          f'got {type(model).__name__}')

      def train(self, *args, **kwargs):
          """Run all stages until ``max_epochs`` epochs are trained in total."""
          if not self.data_val:
              self.gen_val()
          logger.info('Start training...')
          start_time = datetime.datetime.now()
          stage_end_epoch = 0
          for stage, num_epoch in enumerate(self.stages):
              stage_end_epoch += num_epoch
              # When resuming, stages that are already done get a
              # non-positive count and are skipped by run_stage.
              self.run_stage(stage, stage_end_epoch - self._current_epoch)
          elapsed = datetime.datetime.now() - start_time
          logger.info('Total time spent: {:.2f} hours\n'.format(
              elapsed.total_seconds() / 3600.0))

      def run_stage(self, stage, epochs_to_run):
          """Train ``epochs_to_run`` epochs on the data configs of ``stage``.

          Args:
              stage: index of the stage (0 = easy, 1 = normal, 2 = hard).
              epochs_to_run: number of epochs to run; a non-positive value
                  skips the stage with a warning.
          """
          if epochs_to_run <= 0:
              logger.warning(f'Invalid epoch number, stage {stage} exit!')
              return
          logger.info(f'Starting stage {stage}...')
          dataset, dataloader = self.create_dataloader(
              self.conf_files[stage * 2], self.conf_files[stage * 2 + 1])
          try:
              batches = iter(dataloader)
              for _ in range(epochs_to_run):
                  self._train_one_epoch(batches)
          finally:
              # Stop the loader and free the dataset even if training raises.
              dataloader.stop()
              dataset.release()
          logger.info(f'Stage {stage} is finished.')

      def _train_one_epoch(self, batches):
          """Train one epoch from the ``batches`` iterator, validate, checkpoint."""
          self._current_epoch += 1
          epoch_start = datetime.datetime.now()
          logger.info('Start epoch %d...', self._current_epoch)
          loss_sum = 0.0
          valid_batches = 0
          for bi in range(self._train_iters):
              feat, label = next(batches)
              label = torch.reshape(label, (-1, ))
              feat = to_device(feat, self.device)
              label = to_device(label, self.device)
              self.optimizer.zero_grad()
              predict = self.model(feat)
              loss = self.loss_fn(
                  torch.reshape(predict, (-1, self._num_classes)), label)
              loss_value = loss.item()  # read once, reused below
              # NaN losses are neither back-propagated nor averaged.
              if not np.isnan(loss_value):
                  loss.backward()
                  self.optimizer.step()
                  loss_sum += loss_value
                  valid_batches += 1
              train_result = 'Epoch: {:04d}/{:04d}, batch: {:04d}/{:04d}, loss: {:.4f}'.format(
                  self._current_epoch, self._max_epochs, bi + 1,
                  self._train_iters, loss_value)
              logger.info(train_result)
              self._dump_log(train_result)
          # Mean over non-NaN batches; NaN when every batch was skipped.
          loss_train_epoch = (loss_sum / valid_batches
                              if valid_batches else float('nan'))
          loss_val_epoch = self.evaluate('')
          val_result = 'Evaluate epoch: {:04d}, loss_train: {:.4f}, loss_val: {:.4f}'.format(
              self._current_epoch, loss_train_epoch, loss_val_epoch)
          logger.info(val_result)
          self._dump_log(val_result)
          ckpt_name = '{}_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
              CKPT_PREFIX, self._current_epoch, loss_train_epoch,
              loss_val_epoch)
          save_path = os.path.join(self.work_dir, ckpt_name)
          if is_master():
              logger.info(f'Save model to {save_path}')
              # Save the state_dict: checkpoints are reloaded with
              # torch.load(weights_only=True), which rejects pickled modules.
              torch.save(self.model.state_dict(), save_path)
          elapsed = datetime.datetime.now() - epoch_start
          logger.info('Epoch {:04d} time spent: {:.2f} hours'.format(
              self._current_epoch, elapsed.total_seconds() / 3600.0))

      def gen_val(self):
          """Fill ``self.data_val`` with ``[feat, label]`` validation batches.

          On a fresh run ``_val_iters`` batches are drawn from the normal-stage
          configs and pickled to ``work_dir/val_dataset.bin``. When resuming
          (``_current_epoch > 0``) that dump is reloaded so validation matches
          earlier epochs; if it is missing, a new set is generated.
          """
          val_dump_file = os.path.join(self.work_dir, 'val_dataset.bin')
          if self._current_epoch > 0:
              if os.path.isfile(val_dump_file):
                  logger.info('Start loading validation set...')
                  # NOTE: pickle.load can execute arbitrary code; only resume
                  # from a work_dir whose contents you trust.
                  with open(val_dump_file, 'rb') as f:
                      self.data_val = pickle.load(f)
                  logger.info('Finish loading validation set!')
                  return
              logger.warning('Validation dump %s not found, generating a '
                             'new validation set.', val_dump_file)
          logger.info('Start generating validation set...')
          dataset, dataloader = self.create_dataloader(self.conf_files[2],
                                                       self.conf_files[3])
          data_val = []
          try:
              batches = iter(dataloader)
              for bi in range(self._val_iters):
                  logger.info('Iterating validation data %d', bi)
                  feat, label = next(batches)
                  data_val.append([feat, torch.reshape(label, (-1, ))])
          finally:
              dataloader.stop()
              dataset.release()
          self.data_val = data_val
          # Only one rank writes the shared dump to avoid concurrent writes.
          if is_master():
              with open(val_dump_file, 'wb') as f:
                  pickle.dump(self.data_val, f)
          logger.info('Finish generating validation set!')

      def create_dataloader(self, base_path, finetune_path):
          """Build a ``KWSDataset`` and a started ``KWSDataLoader`` over it.

          The caller owns both objects and must call ``dataloader.stop()`` and
          ``dataset.release()`` when done.
          """
          dataset = KWSDataset(base_path, finetune_path, self._threads,
                               self._single_rate, self._num_classes)
          dataloader = KWSDataLoader(
              dataset, batchsize=self._batch_size, numworkers=self._threads)
          dataloader.start()
          return dataset, dataloader

      def evaluate(self, checkpoint_path: str, *args,
                   **kwargs) -> float:
          """Return the mean loss of the current model over ``self.data_val``.

          ``checkpoint_path`` is accepted for ``BaseTrainer`` compatibility but
          ignored. Returns NaN (with a warning) if there is no validation data.
          """
          logger.info('Start validation...')
          if not self.data_val:
              logger.warning('Validation set is empty, skipping validation.')
              return float('nan')
          loss_sum = 0.0
          with torch.no_grad():
              for feat, label in self.data_val:
                  feat = to_device(feat, self.device)
                  label = to_device(label, self.device)
                  predict = self.model(feat)
                  loss = self.loss_fn(
                      torch.reshape(predict, (-1, self._num_classes)), label)
                  loss_sum += loss.item()
          logger.info('Finish validation.')
          # Divide by the batches actually held: a reloaded dump may have been
          # generated with a different val_iters_per_epoch.
          return loss_sum / len(self.data_val)

      def _dump_log(self, msg):
          """Append ``msg`` as one line to the log file (master rank only)."""
          if is_master():
              with open(self.json_log_path, 'a+') as f:
                  f.write(msg)
                  f.write('\n')