| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- """PyTorch trainer for UniTE model."""
- import os.path as osp
- import random
- from math import ceil
- from os import mkdir
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
- import torch
- from pandas import DataFrame
- from torch.nn.functional import pad
- from torch.nn.utils import clip_grad_norm_
- from torch.optim import AdamW, Optimizer
- from torch.utils.data import (BatchSampler, DataLoader, Dataset, Sampler,
- SequentialSampler, SubsetRandomSampler)
- from torch.utils.tensorboard import SummaryWriter
- from tqdm import tqdm
- from transformers import AutoTokenizer
- from modelscope.metainfo import Metrics, Trainers
- from modelscope.metrics import Metric
- from modelscope.metrics.builder import MetricKeys, build_metric
- from modelscope.models.base import TorchModel
- from modelscope.models.nlp.unite.configuration import InputFormat
- from modelscope.models.nlp.unite.translation_evaluation import (
- UniTEForTranslationEvaluation, combine_input_sentences)
- from modelscope.msdatasets import MsDataset
- from modelscope.preprocessors import Preprocessor
- from modelscope.trainers.builder import TRAINERS
- from modelscope.trainers.hooks import Hook
- from modelscope.trainers.trainer import EpochBasedTrainer
- from modelscope.utils.config import ConfigDict
- from modelscope.utils.constant import (ConfigKeys, Fields, ModeKeys, ModelFile,
- TrainerStages)
- from modelscope.utils.device import create_device
- from modelscope.utils.logger import get_logger
logger = get_logger()


class TranslationEvaluationTrainingSampler(Sampler):
    """Batch sampler that mixes the three UniTE input formats in every batch.

    At construction, ``range(n)`` (``n`` = the largest multiple of 3 not
    exceeding ``num_of_samples``) is randomly permuted once and split into
    three equally-sized, disjoint subsets. The subsets are assigned to
    ``InputFormat.SRC_REF``, ``InputFormat.SRC`` and ``InputFormat.REF``, in
    that order. Each yielded batch is the concatenation of one batch of
    ``batch_size_for_each_input_format`` indices from each subset, in the same
    order.
    """

    def __init__(self, num_of_samples: int,
                 batch_size_for_each_input_format: int):
        r"""Build a sampler for model training with translation evaluation trainer.

        The trainer should derive samples for each subset of the entire dataset.

        Args:
            num_of_samples: The number of samples in total.
            batch_size_for_each_input_format: During training, the batch size
                for each input format.

        Raises:
            ValueError: If not even one full batch per input format can be
                formed from the given samples.
        """
        self.num_of_samples = num_of_samples
        self.batch_size_for_each_input_format = batch_size_for_each_input_format
        self.num_of_samples_for_each_input_format = self.num_of_samples // 3
        num_of_samples_to_use = self.num_of_samples_for_each_input_format * 3
        logger.info(
            '%d samples are given for training. '
            'Using %d samples for each input format. '
            'Leaving the last %d samples unused.' %
            (self.num_of_samples, self.num_of_samples_for_each_input_format,
             self.num_of_samples - num_of_samples_to_use))
        self.num_of_samples = num_of_samples_to_use

        random_permutations = torch.randperm(
            self.num_of_samples).cpu().tolist()
        self.subset_iterators = dict()
        self.subset_samplers = dict()
        self.indices_for_each_input_format = dict()
        for input_format_index, input_format in enumerate(
                (InputFormat.SRC_REF, InputFormat.SRC, InputFormat.REF)):
            start_idx = input_format_index * self.num_of_samples_for_each_input_format
            end_idx = start_idx + self.num_of_samples_for_each_input_format
            self.indices_for_each_input_format[
                input_format] = random_permutations[start_idx:end_idx]
            # Shuffles within the subset on every pass; incomplete trailing
            # batches are dropped so all three formats stay in lockstep.
            self.subset_samplers[input_format] = BatchSampler(
                SubsetRandomSampler(
                    self.indices_for_each_input_format[input_format]),
                batch_size=self.batch_size_for_each_input_format,
                drop_last=True)
            self.subset_iterators[input_format] = iter(
                self.subset_samplers[input_format])

        # Kept on the instance (not per generator), so it carries over
        # between successive calls to __iter__.
        self.num_of_sampled_batches = 0
        if self.__len__() == 0:
            # A single concatenated message; previously the two parts were
            # passed as separate arguments and rendered as a tuple.
            raise ValueError(
                'The dataset doesn\'t contain enough examples to form a single batch. '
                'Please reduce the batch_size or use more examples for training.'
            )

    def __iter__(self):
        """Yield mixed-format batches of dataset indices.

        Each batch holds ``batch_size_for_each_input_format`` indices for
        ``InputFormat.SRC_REF``, then ``InputFormat.SRC``, then
        ``InputFormat.REF``.

        NOTE: this generator does not stop by itself. After ``len(self)``
        batches, every per-format iterator is drained and re-created (which
        reshuffles each subset), the counter is reset, and sampling continues.
        It presumably relies on the training loop stopping after
        ``len(self)`` iterations per epoch — confirm against the caller.
        """
        input_formats = (InputFormat.SRC_REF, InputFormat.SRC, InputFormat.REF)
        while True:
            try:
                if self.num_of_sampled_batches == self.__len__():
                    for input_format in input_formats:
                        # Drain whatever is left of the current pass, then
                        # restart the subset sampler for a fresh shuffle.
                        for _ in self.subset_iterators[input_format]:
                            pass
                        self.subset_iterators[input_format] = iter(
                            self.subset_samplers[input_format])
                    self.num_of_sampled_batches = 0
                output = list()
                for input_format in input_formats:
                    output += next(self.subset_iterators[input_format])
                self.num_of_sampled_batches += 1
                yield output
            except StopIteration:
                break

    def __len__(self) -> int:
        """Number of mixed-format batches in one full pass over every subset."""
        return (self.num_of_samples_for_each_input_format
                // self.batch_size_for_each_input_format)
def convert_csv_dict_to_input(
        batch: List[Dict[str, Any]],
        preprocessor: Preprocessor) -> Dict[str, Any]:
    """Transpose a list of row dicts into a column dict and preprocess it.

    Args:
        batch: Rows of the dataset. The keys of ``batch[0]`` are used for
            every row, so the batch must be non-empty and every row must
            contain those keys.
        preprocessor: Callable applied to the column dict
            ``{key: [row[key] for row in batch]}``.

    Returns:
        Whatever ``preprocessor`` returns for the column dict. Callers in this
        module iterate it with ``.items()``, i.e. treat it as a mapping
        (previously mis-annotated as ``Tuple[List[torch.Tensor]]``).
    """
    columns = {key: [row[key] for row in batch] for key in batch[0]}
    return preprocessor(columns)
def data_collate_fn(batch: List[Dict[str, Any]], batch_size: int,
                    preprocessor: Preprocessor) -> Dict[str, Any]:
    """Collate raw dataset rows into a single model-ready batch.

    In ``ModeKeys.TRAIN`` mode, ``batch`` is split into three consecutive
    chunks of ``batch_size`` rows. The chunks are preprocessed with
    ``InputFormat.SRC_REF``, ``InputFormat.SRC`` and ``InputFormat.REF``
    respectively; this matches the layout yielded by
    ``TranslationEvaluationTrainingSampler``. In ``ModeKeys.EVAL`` mode, the
    whole batch is preprocessed with the preprocessor's current
    ``input_format``.

    Args:
        batch: Raw rows as produced by the dataset.
        batch_size: Rows per input format; only used in TRAIN mode.
        preprocessor: Preprocessor whose ``mode`` selects the branch. In TRAIN
            mode its input format is switched per chunk through
            ``change_input_format`` and is left at ``InputFormat.REF``.

    Returns:
        A dict containing:
          * ``input_format``: ``batch_size`` entries per format in TRAIN mode,
            or one per row in EVAL mode.
          * ``input_ids``: chunks right-padded with
            ``preprocessor.pad_token_id`` to a common last-dim length, then
            concatenated along dim 0.
          * ``score``: a 1-D ``torch.Tensor``.
          * In EVAL mode, ``lp``, ``raw_score`` and ``segment_id`` flattened
            into plain lists.
        Any other key returned by the preprocessor is kept as a list with one
        entry per preprocessed chunk.

    Raises:
        ValueError: If ``preprocessor.mode`` is neither TRAIN nor EVAL.
    """
    output_dict: Dict[str, Any] = {'input_format': []}

    def _collect(processed):
        # Append every preprocessed field as one chunk; merged further below.
        for key, value in processed.items():
            output_dict.setdefault(key, []).append(value)

    if preprocessor.mode == ModeKeys.TRAIN:
        for chunk_index, input_format in enumerate(
                (InputFormat.SRC_REF, InputFormat.SRC, InputFormat.REF)):
            start_idx = chunk_index * batch_size
            chunk = batch[start_idx:start_idx + batch_size]
            output_dict['input_format'] += [input_format] * batch_size
            preprocessor.change_input_format(input_format)
            _collect(convert_csv_dict_to_input(chunk, preprocessor))
    elif preprocessor.mode == ModeKeys.EVAL:
        output_dict['input_format'] += [preprocessor.input_format] * len(batch)
        _collect(convert_csv_dict_to_input(batch, preprocessor))
    else:
        # This collator serves both training and evaluation, so the message
        # must not claim the failure happened "during training".
        raise ValueError(
            '%s mode is not allowed for preprocessor; only %s and %s are '
            'supported.' % (preprocessor.mode, ModeKeys.TRAIN, ModeKeys.EVAL))

    input_ids_chunks = output_dict['input_ids']
    input_max_length = max(x.size(-1) for x in input_ids_chunks)
    output_dict['input_ids'] = torch.cat(
        [
            pad(x,
                pad=(0, input_max_length - x.size(-1)),
                value=preprocessor.pad_token_id) for x in input_ids_chunks
        ],
        dim=0)
    output_dict['score'] = torch.Tensor(output_dict['score']).view(-1)
    if preprocessor.mode == ModeKeys.EVAL:
        # Flatten per-chunk lists into one list (linear, unlike sum(..., [])).
        for key in ('lp', 'raw_score', 'segment_id'):
            output_dict[key] = [
                item for chunk in output_dict[key] for item in chunk
            ]
    return output_dict
@TRAINERS.register_module(module_name=Trainers.translation_evaluation_trainer)
class TranslationEvaluationTrainer(EpochBasedTrainer):
    """Epoch-based trainer for the UniTE translation evaluation model.

    Training batches mix the three UniTE input formats, using
    ``TranslationEvaluationTrainingSampler`` and ``data_collate_fn``.
    Evaluation runs one sub-loop per input format.
    """

    def __init__(self,
                 model: Optional[Union[TorchModel, torch.nn.Module,
                                       str]] = None,
                 cfg_file: Optional[str] = None,
                 device: str = 'gpu',
                 *args,
                 **kwargs):
        r"""Build a translation evaluation trainer with a model dir or a model id in the model hub.

        Args:
            model: A Model instance.
            cfg_file: The path for the configuration file (configuration.json).
            device: Used device for this trainer. Forwarded to
                ``EpochBasedTrainer`` as the ``device`` keyword argument.
                Previously it was captured here and silently dropped.
        """

        def data_collator_for_train(x):
            return data_collate_fn(
                x,
                batch_size=self.cfg.train.batch_size,
                preprocessor=self.train_preprocessor)

        def data_collator_for_eval(x):
            return data_collate_fn(
                x,
                batch_size=self.cfg.evaluation.batch_size,
                preprocessor=self.eval_preprocessor)

        data_collator = {
            ConfigKeys.train: data_collator_for_train,
            ConfigKeys.val: data_collator_for_eval
        }

        # NOTE(review): positional *args are expanded right after `model` and
        # would collide with `cfg_file=`; extra options should be keywords.
        super().__init__(
            model,
            cfg_file=cfg_file,
            data_collator=data_collator,
            device=device,
            *args,
            **kwargs)
        self.train_dataloader = None
        self.eval_dataloader = None

    def build_optimizer(self, cfg: ConfigDict) -> Optimizer:
        r"""Sets the optimizers to be used during training.

        For any optimizer type other than ``AdamW``, this defers to
        ``EpochBasedTrainer.build_optimizer(cfg)``.

        For ``AdamW``, it freezes the encoder embeddings and builds parameter
        groups (settings read from ``self.cfg.train.optimizer``):
          * encoder layer ``i``: ``plm_lr * plm_lr_layerwise_decay ** i``;
          * embeddings: ``plm_lr``;
          * estimator and layer-wise attention: ``mlp_lr``.
        """
        optimizer_cfg = self.cfg.train.optimizer
        if optimizer_cfg.type != 'AdamW':
            return super().build_optimizer(cfg)

        # Freezing embedding layers for more efficient training.
        for param in self.model.encoder.embeddings.parameters():
            param.requires_grad = False

        logger.info('Building AdamW optimizer ...')
        learning_rates_and_parameters = [{
            'params':
            self.model.encoder.encoder.layer[i].parameters(),
            'lr':
            optimizer_cfg.plm_lr * optimizer_cfg.plm_lr_layerwise_decay**i,
        } for i in range(0, self.cfg.model.num_hidden_layers)]
        # The frozen embeddings are still registered as a group. They never
        # receive gradients, so AdamW leaves them untouched.
        learning_rates_and_parameters.append({
            'params': self.model.encoder.embeddings.parameters(),
            'lr': optimizer_cfg.plm_lr,
        })
        learning_rates_and_parameters.append({
            'params': self.model.estimator.parameters(),
            'lr': optimizer_cfg.mlp_lr
        })
        learning_rates_and_parameters.append({
            'params': self.model.layerwise_attention.parameters(),
            'lr': optimizer_cfg.mlp_lr,
        })

        optimizer = AdamW(
            learning_rates_and_parameters,
            lr=optimizer_cfg.plm_lr,
            betas=optimizer_cfg.betas,
            eps=optimizer_cfg.eps,
            weight_decay=optimizer_cfg.weight_decay,
        )
        return optimizer

    def get_train_dataloader(self) -> DataLoader:
        """Build the training ``DataLoader``.

        If no training dataset was supplied, it is loaded with
        ``MsDataset.load`` from ``self.cfg.dataset.train.name`` (relative to
        ``self.model_dir``). Batches come from
        ``TranslationEvaluationTrainingSampler``, with
        ``self.cfg.train.batch_size`` rows per input format.
        """
        logger.info('Building dataloader for training ...')
        if self.train_dataset is None:
            logger.info('Reading train csv file from %s ...'
                        % self.cfg.dataset.train.name)
            self.train_dataset = MsDataset.load(
                osp.join(self.model_dir, self.cfg.dataset.train.name),
                split=self.cfg.dataset.train.split)
        train_dataloader = DataLoader(
            self.train_dataset,
            batch_sampler=TranslationEvaluationTrainingSampler(
                len(self.train_dataset),
                batch_size_for_each_input_format=self.cfg.train.batch_size),
            num_workers=4,
            collate_fn=self.train_data_collator,
            generator=None)
        logger.info('Reading done, %d items in total'
                    % len(self.train_dataset))
        return train_dataloader

    def get_eval_data_loader(self) -> DataLoader:
        """Build the evaluation ``DataLoader``.

        If no evaluation dataset was supplied, it is loaded with
        ``MsDataset.load`` from ``self.cfg.dataset.valid.name`` (relative to
        ``self.model_dir``). Batches are sequential, of size
        ``self.cfg.evaluation.batch_size``, and the last partial batch is kept.
        """
        logger.info('Building dataloader for evaluating ...')
        if self.eval_dataset is None:
            logger.info('Reading eval csv file from %s ...'
                        % self.cfg.dataset.valid.name)
            self.eval_dataset = MsDataset.load(
                osp.join(self.model_dir, self.cfg.dataset.valid.name),
                split=self.cfg.dataset.valid.split)
        eval_dataloader = DataLoader(
            self.eval_dataset,
            batch_sampler=BatchSampler(
                SequentialSampler(range(0, len(self.eval_dataset))),
                batch_size=self.cfg.evaluation.batch_size,
                drop_last=False),
            num_workers=4,
            collate_fn=self.eval_data_collator,
            generator=None)
        logger.info('Reading done, %d items in total' % len(self.eval_dataset))
        return eval_dataloader

    def evaluation_loop(self, data_loader, metric_classes):
        """ Evaluation loop used by `TranslationEvaluationTrainer.evaluate()`.

        The evaluation process of UniTE model should be arranged with three
        loops, corresponding to the input formats `InputFormat.SRC_REF`,
        `InputFormat.SRC` and `InputFormat.REF` (run in that order).
        Here we directly copy the codes of `EpochBasedTrainer.evaluation_loop`,
        and change the input format during each evaluation subloop.

        The results of the sub-loops are merged with ``dict.update``. A key
        reported by a later sub-loop therefore overwrites the same key from an
        earlier one. Metrics are cleared only after all three sub-loops.

        Returns:
            dict: The merged metric values.
        """
        # Local import: needed for the visualization closure below.
        from functools import partial

        vis_closure = None
        if hasattr(self.cfg.evaluation, 'visualization'):
            vis_cfg = self.cfg.evaluation.visualization
            vis_closure = partial(
                self.visualization, dataset=self.eval_dataset, **vis_cfg)

        self.invoke_hook(TrainerStages.before_val)
        metric_values = dict()
        for input_format in (InputFormat.SRC_REF, InputFormat.SRC,
                             InputFormat.REF):
            # The eval collator reads the preprocessor's current input format.
            self.eval_preprocessor.change_input_format(input_format)
            if self._dist:
                from modelscope.trainers.utils.inference import multi_gpu_test
                # list of batched result and data samples
                metric_values.update(
                    multi_gpu_test(
                        self,
                        data_loader,
                        device=self.device,
                        metric_classes=metric_classes,
                        vis_closure=vis_closure,
                        tmpdir=self.cfg.evaluation.get('cache_dir', None),
                        gpu_collect=self.cfg.evaluation.get(
                            'gpu_collect', False),
                        data_loader_iters_per_gpu=self._eval_iters_per_epoch))
            else:
                from modelscope.trainers.utils.inference import single_gpu_test
                metric_values.update(
                    single_gpu_test(
                        self,
                        data_loader,
                        device=self.device,
                        metric_classes=metric_classes,
                        vis_closure=vis_closure,
                        data_loader_iters=self._eval_iters_per_epoch))

        for m in metric_classes:
            if hasattr(m, 'clear') and callable(m.clear):
                m.clear()

        self.invoke_hook(TrainerStages.after_val)
        return metric_values
|