nlp_trainer.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Tuple, Union
  4. import numpy as np
  5. from torch import nn
  6. from modelscope.metainfo import Trainers
  7. from modelscope.metrics.builder import build_metric
  8. from modelscope.models.base import Model, TorchModel
  9. from modelscope.preprocessors import Preprocessor
  10. from modelscope.utils.config import Config
  11. from modelscope.utils.constant import ModeKeys
  12. from .base import TRAINERS
  13. from .trainer import EpochBasedTrainer
  14. @TRAINERS.register_module(module_name=Trainers.nlp_base_trainer)
  15. class NlpEpochBasedTrainer(EpochBasedTrainer):
  16. """Add code to adapt with nlp models.
  17. This trainer will accept the information of labels&text keys in the cfg, and then initialize
  18. the nlp models/preprocessors with this information.
  19. Labels&text key information may be carried in the cfg like this:
  20. >>> cfg = {
  21. >>> ...
  22. >>> "dataset": {
  23. >>> "train": {
  24. >>> "first_sequence": "text1",
  25. >>> "second_sequence": "text2",
  26. >>> "label": "label",
  27. >>> "labels": [1, 2, 3, 4],
  28. >>> },
  29. >>> "val": {
  30. >>> "first_sequence": "text3",
  31. >>> "second_sequence": "text4",
  32. >>> "label": "label2",
  33. >>> },
  34. >>> }
  35. >>> }
  36. To view some actual finetune examples, please check the test files listed below:
  37. tests/trainers/test_finetune_sequence_classification.py
  38. tests/trainers/test_finetune_token_classification.py
  39. """
  40. def __init__(self, *args, **kwargs):
  41. self.label2id = None
  42. self.id2label = None
  43. self.num_labels = None
  44. self.train_keys = None
  45. self.eval_keys = None
  46. super().__init__(*args, **kwargs)
  47. def prepare_labels(self, cfg):
  48. try:
  49. labels = cfg.dataset.train.labels
  50. self.label2id = {label: idx for idx, label in enumerate(labels)}
  51. self.id2label = {idx: label for idx, label in enumerate(labels)}
  52. self.num_labels = len(labels)
  53. except AttributeError:
  54. pass
  55. def build_dataset_keys(cfg):
  56. if cfg is not None:
  57. input_keys = {
  58. 'first_sequence': getattr(cfg, 'first_sequence', None),
  59. 'second_sequence': getattr(cfg, 'second_sequence', None),
  60. 'label': getattr(cfg, 'label', None),
  61. }
  62. else:
  63. input_keys = {}
  64. return {k: v for k, v in input_keys.items() if v is not None}
  65. self.train_keys = build_dataset_keys(cfg.safe_get('dataset.train'))
  66. self.eval_keys = build_dataset_keys(cfg.safe_get('dataset.val'))
  67. if len(self.eval_keys) == 0:
  68. self.eval_keys = self.train_keys
  69. def rebuild_config(self, cfg: Config):
  70. if self.cfg_modify_fn is not None:
  71. cfg = self.cfg_modify_fn(cfg)
  72. self.prepare_labels(cfg)
  73. if not hasattr(cfg.model, 'label2id') and not hasattr(
  74. cfg.model, 'id2label'):
  75. if self.id2label is not None:
  76. cfg.model['id2label'] = self.id2label
  77. if self.label2id is not None:
  78. cfg.model['label2id'] = self.label2id
  79. return cfg
  80. def build_model(self) -> Union[nn.Module, TorchModel]:
  81. """ Instantiate a pytorch model and return.
  82. By default, we will create a model using config from configuration file. You can
  83. override this method in a subclass.
  84. """
  85. model_args = {} if self.num_labels is None else {
  86. 'num_labels': self.num_labels
  87. }
  88. model = Model.from_pretrained(
  89. self.model_dir, cfg_dict=self.cfg, **model_args)
  90. if not isinstance(model, nn.Module) and hasattr(model, 'model'):
  91. return model.model
  92. elif isinstance(model, nn.Module):
  93. return model
  94. def build_preprocessor(self) -> Tuple[Preprocessor, Preprocessor]:
  95. """Build the preprocessor.
  96. User can override this method to implement custom logits.
  97. Returns: The preprocessor instance.
  98. """
  99. # Compatible with old logic
  100. extra_args = {} if self.label2id is None else {
  101. 'label2id': self.label2id
  102. }
  103. train_preprocessor = Preprocessor.from_pretrained(
  104. self.model_dir,
  105. cfg_dict=self.cfg,
  106. preprocessor_mode=ModeKeys.TRAIN,
  107. **extra_args,
  108. **self.train_keys,
  109. mode=ModeKeys.TRAIN,
  110. use_fast=True)
  111. eval_preprocessor = Preprocessor.from_pretrained(
  112. self.model_dir,
  113. cfg_dict=self.cfg,
  114. preprocessor_mode=ModeKeys.EVAL,
  115. **extra_args,
  116. **self.eval_keys,
  117. mode=ModeKeys.EVAL,
  118. use_fast=True)
  119. return train_preprocessor, eval_preprocessor
  120. @TRAINERS.register_module(module_name=Trainers.nlp_veco_trainer)
  121. class VecoTrainer(NlpEpochBasedTrainer):
  122. def evaluate(self, checkpoint_path=None):
  123. """Veco evaluates the datasets one by one.
  124. """
  125. from modelscope.msdatasets.dataset_cls.custom_datasets import VecoDataset
  126. if checkpoint_path is not None:
  127. from modelscope.trainers.hooks import LoadCheckpointHook
  128. LoadCheckpointHook.load_checkpoint(checkpoint_path, self)
  129. self.model.eval()
  130. self._mode = ModeKeys.EVAL
  131. metric_values = {}
  132. if self.eval_dataset is None:
  133. self.eval_dataset = self.build_dataset_from_cfg(
  134. model_cfg=self.cfg,
  135. mode=self._mode,
  136. preprocessor=self.eval_preprocessor)
  137. idx = 0
  138. dataset_cnt = 1
  139. if isinstance(self.eval_dataset, VecoDataset):
  140. self.eval_dataset.switch_dataset(idx)
  141. dataset_cnt = len(self.eval_dataset.datasets)
  142. while True:
  143. self.eval_dataloader = self._build_dataloader_with_dataset(
  144. self.eval_dataset, **self.cfg.evaluation.get('dataloader', {}))
  145. self.data_loader = self.eval_dataloader
  146. metric_classes = [build_metric(metric) for metric in self.metrics]
  147. for m in metric_classes:
  148. m.trainer = self
  149. self.evaluation_loop(self.eval_dataloader, metric_classes)
  150. for m_idx, metric_cls in enumerate(metric_classes):
  151. if f'eval_dataset[{idx}]' not in metric_values:
  152. metric_values[f'eval_dataset[{idx}]'] = {}
  153. metric_values[f'eval_dataset[{idx}]'][
  154. self.metrics[m_idx]] = metric_cls.evaluate()
  155. idx += 1
  156. if idx < dataset_cnt:
  157. self.eval_dataset.switch_dataset(idx)
  158. else:
  159. break
  160. for metric_name in self.metrics:
  161. all_metrics = [m[metric_name] for m in metric_values.values()]
  162. for key in all_metrics[0].keys():
  163. metric_values[key] = np.average(
  164. [metric[key] for metric in all_metrics])
  165. return metric_values