translation_evaluation_trainer.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. """PyTorch trainer for UniTE model."""
  3. import os.path as osp
  4. import random
  5. from math import ceil
  6. from os import mkdir
  7. from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  8. import torch
  9. from pandas import DataFrame
  10. from torch.nn.functional import pad
  11. from torch.nn.utils import clip_grad_norm_
  12. from torch.optim import AdamW, Optimizer
  13. from torch.utils.data import (BatchSampler, DataLoader, Dataset, Sampler,
  14. SequentialSampler, SubsetRandomSampler)
  15. from torch.utils.tensorboard import SummaryWriter
  16. from tqdm import tqdm
  17. from transformers import AutoTokenizer
  18. from modelscope.metainfo import Metrics, Trainers
  19. from modelscope.metrics import Metric
  20. from modelscope.metrics.builder import MetricKeys, build_metric
  21. from modelscope.models.base import TorchModel
  22. from modelscope.models.nlp.unite.configuration import InputFormat
  23. from modelscope.models.nlp.unite.translation_evaluation import (
  24. UniTEForTranslationEvaluation, combine_input_sentences)
  25. from modelscope.msdatasets import MsDataset
  26. from modelscope.preprocessors import Preprocessor
  27. from modelscope.trainers.builder import TRAINERS
  28. from modelscope.trainers.hooks import Hook
  29. from modelscope.trainers.trainer import EpochBasedTrainer
  30. from modelscope.utils.config import ConfigDict
  31. from modelscope.utils.constant import (ConfigKeys, Fields, ModeKeys, ModelFile,
  32. TrainerStages)
  33. from modelscope.utils.device import create_device
  34. from modelscope.utils.logger import get_logger
  35. logger = get_logger()
  36. class TranslationEvaluationTrainingSampler(Sampler):
  37. def __init__(self, num_of_samples: int,
  38. batch_size_for_each_input_format: int):
  39. r"""Build a sampler for model training with translation evaluation trainer.
  40. The trainer should derive samples for each subset of the entire dataset.
  41. Args:
  42. num_of_samples: The number of samples in total.
  43. batch_size_for_each_input_format: During training, the batch size for each input format
  44. Returns:
  45. A data sampler for translation evaluation model training.
  46. """
  47. self.num_of_samples = num_of_samples
  48. self.batch_size_for_each_input_format = batch_size_for_each_input_format
  49. self.num_of_samples_for_each_input_format = self.num_of_samples // 3
  50. num_of_samples_to_use = self.num_of_samples_for_each_input_format * 3
  51. logger.info(
  52. '%d samples are given for training. '
  53. 'Using %d samples for each input format. '
  54. 'Leaving the last %d samples unused.' %
  55. (self.num_of_samples, self.num_of_samples_for_each_input_format,
  56. self.num_of_samples - num_of_samples_to_use))
  57. self.num_of_samples = num_of_samples_to_use
  58. random_permutations = torch.randperm(
  59. self.num_of_samples).cpu().tolist()
  60. self.subset_iterators = dict()
  61. self.subset_samplers = dict()
  62. self.indices_for_each_input_format = dict()
  63. for input_format_index, input_format in \
  64. enumerate((InputFormat.SRC_REF, InputFormat.SRC, InputFormat.REF)):
  65. start_idx = input_format_index * self.num_of_samples_for_each_input_format
  66. end_idx = start_idx + self.num_of_samples_for_each_input_format
  67. self.indices_for_each_input_format[
  68. input_format] = random_permutations[start_idx:end_idx]
  69. self.subset_samplers[input_format] = \
  70. BatchSampler(SubsetRandomSampler(self.indices_for_each_input_format[input_format]),
  71. batch_size=self.batch_size_for_each_input_format,
  72. drop_last=True)
  73. self.subset_iterators[input_format] = iter(
  74. self.subset_samplers[input_format])
  75. self.num_of_sampled_batches = 0
  76. if self.__len__() == 0:
  77. raise ValueError(
  78. 'The dataset doesn\'t contain enough examples to form a single batch.',
  79. 'Please reduce the batch_size or use more examples for training.'
  80. )
  81. return
  82. def __iter__(self):
  83. while True:
  84. try:
  85. if self.num_of_sampled_batches == self.__len__():
  86. for input_format in (InputFormat.SRC_REF, InputFormat.SRC,
  87. InputFormat.REF):
  88. while True:
  89. try:
  90. next(self.subset_iterators[input_format])
  91. except StopIteration:
  92. self.subset_iterators[input_format] = \
  93. iter(self.subset_samplers[input_format])
  94. break
  95. self.num_of_sampled_batches = 0
  96. output = list()
  97. for input_format_idx, input_format in \
  98. enumerate((InputFormat.SRC_REF, InputFormat.SRC, InputFormat.REF)):
  99. output += next(self.subset_iterators[input_format])
  100. self.num_of_sampled_batches += 1
  101. yield output
  102. except StopIteration:
  103. break
  104. def __len__(self) -> int:
  105. return self.num_of_samples_for_each_input_format // self.batch_size_for_each_input_format
  106. def convert_csv_dict_to_input(
  107. batch: List[Dict[str, Any]],
  108. preprocessor: Preprocessor) -> Tuple[List[torch.Tensor]]:
  109. input_dict = dict()
  110. for key in batch[0].keys():
  111. input_dict[key] = list(x[key] for x in batch)
  112. input_dict = preprocessor(input_dict)
  113. return input_dict
  114. def data_collate_fn(batch: List[Dict[str, Any]], batch_size: int,
  115. preprocessor: Preprocessor) -> List[Dict[str, Any]]:
  116. output_dict = dict()
  117. output_dict['input_format'] = list()
  118. if preprocessor.mode == ModeKeys.TRAIN:
  119. for input_format_index, input_format in \
  120. enumerate((InputFormat.SRC_REF, InputFormat.SRC, InputFormat.REF)):
  121. start_idx = input_format_index * batch_size
  122. end_idx = start_idx + batch_size
  123. batch_to_process = batch[start_idx:end_idx]
  124. output_dict['input_format'] += [input_format] * batch_size
  125. preprocessor.change_input_format(input_format)
  126. batch_to_process = convert_csv_dict_to_input(
  127. batch_to_process, preprocessor)
  128. for key, value in batch_to_process.items():
  129. if key not in output_dict.keys():
  130. output_dict[key] = list()
  131. output_dict[key].append(value)
  132. elif preprocessor.mode == ModeKeys.EVAL:
  133. output_dict['input_format'] += [preprocessor.input_format] * len(batch)
  134. batch = convert_csv_dict_to_input(batch, preprocessor)
  135. for key, value in batch.items():
  136. if key not in output_dict.keys():
  137. output_dict[key] = list()
  138. output_dict[key].append(value)
  139. else:
  140. raise ValueError(
  141. 'During training, %s mode is not allowed for preprocessor.'
  142. % preprocessor.mode)
  143. input_max_lengths = max(x.size(-1) for x in output_dict['input_ids'])
  144. output_dict['input_ids'] = list(
  145. pad(x,
  146. pad=(0, input_max_lengths - x.size(-1)),
  147. value=preprocessor.pad_token_id) for x in output_dict['input_ids'])
  148. output_dict['input_ids'] = torch.cat(output_dict['input_ids'], dim=0)
  149. output_dict['score'] = torch.Tensor(output_dict['score']).view(-1)
  150. if preprocessor.mode == ModeKeys.EVAL:
  151. output_dict['lp'] = sum(output_dict['lp'], list())
  152. output_dict['raw_score'] = sum(output_dict['raw_score'], list())
  153. output_dict['segment_id'] = sum(output_dict['segment_id'], list())
  154. return output_dict
  155. @TRAINERS.register_module(module_name=Trainers.translation_evaluation_trainer)
  156. class TranslationEvaluationTrainer(EpochBasedTrainer):
  157. def __init__(self,
  158. model: Optional[Union[TorchModel, torch.nn.Module,
  159. str]] = None,
  160. cfg_file: Optional[str] = None,
  161. device: str = 'gpu',
  162. *args,
  163. **kwargs):
  164. r"""Build a translation evaluation trainer with a model dir or a model id in the model hub.
  165. Args:
  166. model: A Model instance.
  167. cfg_file: The path for the configuration file (configuration.json).
  168. device: Used device for this trainer.
  169. """
  170. def data_collator_for_train(x):
  171. return data_collate_fn(
  172. x,
  173. batch_size=self.cfg.train.batch_size,
  174. preprocessor=self.train_preprocessor)
  175. def data_collator_for_eval(x):
  176. return data_collate_fn(
  177. x,
  178. batch_size=self.cfg.evaluation.batch_size,
  179. preprocessor=self.eval_preprocessor)
  180. data_collator = {
  181. ConfigKeys.train: data_collator_for_train,
  182. ConfigKeys.val: data_collator_for_eval
  183. }
  184. super().__init__(
  185. model,
  186. cfg_file=cfg_file,
  187. data_collator=data_collator,
  188. *args,
  189. **kwargs)
  190. self.train_dataloader = None
  191. self.eval_dataloader = None
  192. return
  193. def build_optimizer(self, cfg: ConfigDict) -> Optimizer:
  194. r"""Sets the optimizers to be used during training."""
  195. if self.cfg.train.optimizer.type != 'AdamW':
  196. return super().build_optimizer(cfg)
  197. # Freezing embedding layers for more efficient training.
  198. for param in self.model.encoder.embeddings.parameters():
  199. param.requires_grad = False
  200. logger.info('Building AdamW optimizer ...')
  201. learning_rates_and_parameters = list({
  202. 'params':
  203. self.model.encoder.encoder.layer[i].parameters(),
  204. 'lr':
  205. self.cfg.train.optimizer.plm_lr
  206. * self.cfg.train.optimizer.plm_lr_layerwise_decay**i,
  207. } for i in range(0, self.cfg.model.num_hidden_layers))
  208. learning_rates_and_parameters.append({
  209. 'params':
  210. self.model.encoder.embeddings.parameters(),
  211. 'lr':
  212. self.cfg.train.optimizer.plm_lr,
  213. })
  214. learning_rates_and_parameters.append({
  215. 'params':
  216. self.model.estimator.parameters(),
  217. 'lr':
  218. self.cfg.train.optimizer.mlp_lr
  219. })
  220. learning_rates_and_parameters.append({
  221. 'params':
  222. self.model.layerwise_attention.parameters(),
  223. 'lr':
  224. self.cfg.train.optimizer.mlp_lr,
  225. })
  226. optimizer = AdamW(
  227. learning_rates_and_parameters,
  228. lr=self.cfg.train.optimizer.plm_lr,
  229. betas=self.cfg.train.optimizer.betas,
  230. eps=self.cfg.train.optimizer.eps,
  231. weight_decay=self.cfg.train.optimizer.weight_decay,
  232. )
  233. return optimizer
  234. def get_train_dataloader(self) -> DataLoader:
  235. logger.info('Building dataloader for training ...')
  236. if self.train_dataset is None:
  237. logger.info('Reading train csv file from %s ...'
  238. % self.cfg.dataset.train.name)
  239. self.train_dataset = MsDataset.load(
  240. osp.join(self.model_dir, self.cfg.dataset.train.name),
  241. split=self.cfg.dataset.train.split)
  242. train_dataloader = DataLoader(
  243. self.train_dataset,
  244. batch_sampler=TranslationEvaluationTrainingSampler(
  245. len(self.train_dataset),
  246. batch_size_for_each_input_format=self.cfg.train.batch_size),
  247. num_workers=4,
  248. collate_fn=self.train_data_collator,
  249. generator=None)
  250. logger.info('Reading done, %d items in total'
  251. % len(self.train_dataset))
  252. return train_dataloader
  253. def get_eval_data_loader(self) -> DataLoader:
  254. logger.info('Building dataloader for evaluating ...')
  255. if self.eval_dataset is None:
  256. logger.info('Reading eval csv file from %s ...'
  257. % self.cfg.dataset.valid.name)
  258. self.eval_dataset = MsDataset.load(
  259. osp.join(self.model_dir, self.cfg.dataset.valid.name),
  260. split=self.cfg.dataset.valid.split)
  261. eval_dataloader = DataLoader(
  262. self.eval_dataset,
  263. batch_sampler=BatchSampler(
  264. SequentialSampler(range(0, len(self.eval_dataset))),
  265. batch_size=self.cfg.evaluation.batch_size,
  266. drop_last=False),
  267. num_workers=4,
  268. collate_fn=self.eval_data_collator,
  269. generator=None)
  270. logger.info('Reading done, %d items in total' % len(self.eval_dataset))
  271. return eval_dataloader
  272. def evaluation_loop(self, data_loader, metric_classes):
  273. """ Evaluation loop used by `TranslationEvaluationTrainer.evaluate()`.
  274. The evaluation process of UniTE model should be arranged with three loops,
  275. corresponding to the input formats of `InputFormat.SRC_REF`, `InputFormat.REF`,
  276. and `InputFormat.SRC`.
  277. Here we directly copy the codes of `EpochBasedTrainer.evaluation_loop`, and change
  278. the input format during each evaluation subloop.
  279. """
  280. vis_closure = None
  281. if hasattr(self.cfg.evaluation, 'visualization'):
  282. vis_cfg = self.cfg.evaluation.visualization
  283. vis_closure = partial(
  284. self.visualization, dataset=self.eval_dataset, **vis_cfg)
  285. self.invoke_hook(TrainerStages.before_val)
  286. metric_values = dict()
  287. for input_format in (InputFormat.SRC_REF, InputFormat.SRC,
  288. InputFormat.REF):
  289. self.eval_preprocessor.change_input_format(input_format)
  290. if self._dist:
  291. from modelscope.trainers.utils.inference import multi_gpu_test
  292. # list of batched result and data samples
  293. metric_values.update(
  294. multi_gpu_test(
  295. self,
  296. data_loader,
  297. device=self.device,
  298. metric_classes=metric_classes,
  299. vis_closure=vis_closure,
  300. tmpdir=self.cfg.evaluation.get('cache_dir', None),
  301. gpu_collect=self.cfg.evaluation.get(
  302. 'gpu_collect', False),
  303. data_loader_iters_per_gpu=self._eval_iters_per_epoch))
  304. else:
  305. from modelscope.trainers.utils.inference import single_gpu_test
  306. metric_values.update(
  307. single_gpu_test(
  308. self,
  309. data_loader,
  310. device=self.device,
  311. metric_classes=metric_classes,
  312. vis_closure=vis_closure,
  313. data_loader_iters=self._eval_iters_per_epoch))
  314. for m in metric_classes:
  315. if hasattr(m, 'clear') and callable(m.clear):
  316. m.clear()
  317. self.invoke_hook(TrainerStages.after_val)
  318. return metric_values