sequence_classification_trainer.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import time
  3. from typing import Dict, Optional, Tuple, Union
  4. import numpy as np
  5. from modelscope.metainfo import Trainers
  6. from modelscope.trainers.base import BaseTrainer
  7. from modelscope.trainers.builder import TRAINERS
  8. from modelscope.utils.logger import get_logger
  # Module-level logger used by the trainer below. PATH is the value handed to
  # get_logger; it is left as None here (presumably meaning "no log file" --
  # TODO confirm against modelscope.utils.logger.get_logger).
  PATH = None
  logger = get_logger(PATH)
  @TRAINERS.register_module(module_name=Trainers.bert_sentiment_analysis)
  class SequenceClassificationTrainer(BaseTrainer):
      """Trainer for sequence-classification models.

      Registered under ``Trainers.bert_sentiment_analysis``. Training is not
      implemented yet (``train`` only logs); ``evaluate`` runs an easynlp
      ``SequenceClassification`` model over the validation split of a dataset
      and reports the metrics requested in the config file.
      """

      def __init__(self, cfg_file: str, *args, **kwargs):
          """A trainer used for Sequence Classification.

          Based on a config file (*.yaml or *.json), the trainer trains or
          evaluates on a dataset.

          Args:
              cfg_file (str): the path of the config file. It is forwarded to
                  ``BaseTrainer.__init__``; ``*args`` and ``**kwargs`` are
                  accepted but currently ignored.
          """
          super().__init__(cfg_file)

      def train(self, *args, **kwargs):
          """Placeholder: training is not implemented, this only logs 'Train'."""
          logger.info('Train')
          ...

      def __attr_is_exist(self, attr: str) -> Tuple[str, object]:
          """Look up a nested config attribute given as a space-separated path.

          ``'model path'`` is resolved as ``self.cfg['model']['path']``.
          Existence is tested with ``hasattr`` at every level and values are
          read with ``[]``, so the config objects are expected to support both
          attribute and item access (``self.cfg`` is presumably populated by
          ``BaseTrainer`` -- TODO confirm).

          Example:
              ``self.__attr_is_exist('model path')`` may return
              ``('model-path', '/workspace/bert-base-sst2')``, while
              ``self.__attr_is_exist('model weights')`` returns
              ``('model-weights', False)`` when no such entry exists.

          Args:
              attr (str): space-separated attribute path,
                  e.g. ``'model path'`` -> ``config['model']['path']``.

          Returns:
              Tuple[str, object]: the dash-joined attribute name (used in
              error messages) and the attribute value, or ``False`` when the
              attribute does not exist or its value is falsy (``None``,
              ``''``, ``0``, an empty container, ...).
          """
          path_parts = attr.split(' ')
          attr_name: str = '-'.join(path_parts)
          target = (self.cfg[path_parts[0]]
                    if hasattr(self.cfg, path_parts[0]) else None)
          for part in path_parts[1:]:
              if not hasattr(target, part):
                  return attr_name, False
              target = target[part]
          # A falsy value (the empty string included) is reported exactly like
          # a missing attribute.
          if target:
              return attr_name, target
          return attr_name, False

      def evaluate(self,
                   checkpoint_path: Optional[str] = None,
                   *args,
                   **kwargs) -> Dict[str, float]:
          """Evaluate a model on the validation split of a dataset.

          The model is loaded from ``checkpoint_path``; when it is not given,
          ``evaluation.model_path`` from the config file is used instead.

          Config entries read:
              - ``evaluation.metrics`` (required): one metric name or a list
                of them. Supported: names ending in ``accuracy``, and ``f1``.
              - ``evaluation.batch_size`` (required): evaluation batch size.
              - ``dataset.valid.file`` (required): dataset identifier; it is
                split on ``'/'`` and the parts are passed to easynlp's
                ``load_dataset``, whose ``'validation'`` split is evaluated.
              - ``evaluation.model_path``: used only when ``checkpoint_path``
                is not passed.
              - ``evaluation.max_sequence_length``: used only when the
                ``max_sequence_length`` keyword argument is not passed.

          Args:
              checkpoint_path (Optional[str], optional): the model path.
                  Defaults to None.
              **kwargs: ``max_sequence_length`` overrides
                  ``evaluation.max_sequence_length`` from the config file.

          Returns:
              Dict[str, float]: the results of the evaluation, e.g.
              {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}.
              When the model has more than two labels, ``f1`` is reported as
              ``macro-f1`` and ``micro-f1`` instead.

          Raises:
              AttributeError: a required config entry is missing or empty.
              ValueError: no checkpoint path or max sequence length is
                  available, or the validation split yields no samples.
              NotImplementedError: an unsupported metric is requested.
              KeyError: the model outputs contain no ``'logits'``.
              RuntimeError: the logits have an unexpected rank.
          """
          import torch
          from easynlp.appzoo import load_dataset
          from easynlp.appzoo.dataset import GeneralDataset
          from easynlp.appzoo.sequence_classification.model import \
              SequenceClassification
          from easynlp.utils import losses
          from sklearn.metrics import f1_score
          from torch.utils.data import DataLoader

          raise_str = 'Attribute {} is not given in config file!'
          metrics = self.__attr_is_exist('evaluation metrics')
          eval_batch_size = self.__attr_is_exist('evaluation batch_size')
          test_dataset_path = self.__attr_is_exist('dataset valid file')
          for attr_name, attr_value in (metrics, eval_batch_size,
                                        test_dataset_path):
              if not attr_value:
                  raise AttributeError(raise_str.format(attr_name))

          if not checkpoint_path:
              checkpoint_path = self.__attr_is_exist(
                  'evaluation model_path')[-1]
          if not checkpoint_path:
              raise ValueError(
                  'Argument checkpoint_path must be passed if the evaluation-model_path is not given in config file!'
              )

          max_sequence_length = kwargs.get(
              'max_sequence_length',
              self.__attr_is_exist('evaluation max_sequence_length')[-1])
          if not max_sequence_length:
              raise ValueError(
                  'Argument max_sequence_length must be passed '
                  'if the evaluation-max_sequence_length does not exist in config file!'
              )

          # Accept a single metric name as well as a list of names; iterating
          # a bare string would otherwise yield one "metric" per character.
          metric_names = metrics[-1]
          if isinstance(metric_names, str):
              metric_names = [metric_names]
          # Fail fast, before the (expensive) inference pass.
          for metric in metric_names:
              if not (metric.endswith('accuracy') or metric == 'f1'):
                  raise NotImplementedError('Metric %s not implemented' %
                                            metric)

          # get the raw online dataset: 'a/b' -> load_dataset('a', 'b')
          raw_dataset = load_dataset(*test_dataset_path[-1].split('/'))
          valid_dataset = raw_dataset['validation']
          # generate a standard dataloader
          pre_dataset = GeneralDataset(valid_dataset, checkpoint_path,
                                       max_sequence_length)
          valid_dataloader = DataLoader(
              pre_dataset,
              batch_size=eval_batch_size[-1],
              shuffle=False,
              collate_fn=pre_dataset.batch_fn)
          # generate a model
          model = SequenceClassification.from_pretrained(checkpoint_path)

          # copy from easynlp (start)
          device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
          model.eval()
          model.to(device)

          total_loss = 0.0
          total_steps = 0
          hit_num = 0
          total_num = 0  # real number of evaluated samples
          logits_list = []
          y_trues = []
          total_spent_time = 0.0
          for _step, batch in enumerate(valid_dataloader):
              try:
                  batch = {
                      key:
                      val.to(device) if isinstance(val, torch.Tensor) else val
                      for key, val in batch.items()
                  }
              except RuntimeError:
                  # Best-effort fallback kept from easynlp: if moving the batch
                  # to the device fails, use (a shallow copy of) it as collated.
                  batch = dict(batch)
              infer_start_time = time.time()
              with torch.no_grad():
                  label_ids = batch.pop('label_ids')
                  outputs = model(batch)
              infer_end_time = time.time()
              total_spent_time += infer_end_time - infer_start_time

              if 'logits' not in outputs:
                  raise KeyError('Model outputs do not contain "logits"')
              logits = outputs['logits']
              y_trues.extend(label_ids.tolist())
              logits_list.extend(logits.tolist())
              hit_num += torch.sum(
                  torch.argmax(logits, dim=-1) == label_ids).item()
              # Count the actual batch size: the last batch may be smaller than
              # the configured batch_size.
              total_num += label_ids.shape[0]

              if len(logits.shape) == 1 or logits.shape[-1] == 1:
                  tmp_loss = losses.mse_loss(logits, label_ids)
              elif len(logits.shape) == 2:
                  tmp_loss = losses.cross_entropy(logits, label_ids)
              else:
                  raise RuntimeError('Unexpected logits shape: {}'.format(
                      tuple(logits.shape)))
              total_loss += tmp_loss.mean().item()
              total_steps += 1

              if (_step + 1) % 100 == 0:
                  # len(DataLoader) includes the final partial batch, unlike
                  # len(dataset) // batch_size.
                  logger.info('Eval: {}/{} steps finished'.format(
                      _step + 1, len(valid_dataloader)))

          # Every division below needs at least one evaluated sample.
          if total_num == 0:
              raise ValueError(
                  'The validation dataset is empty, nothing to evaluate!')

          logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format(
              total_spent_time, total_spent_time * 1000 / total_num))
          eval_loss = total_loss / total_steps
          logger.info('Eval loss: {}'.format(eval_loss))

          # Predicted class per sample, computed once for all F1 variants.
          y_preds = np.argmax(np.array(logits_list), axis=-1)
          eval_outputs = []
          for metric in metric_names:
              if metric.endswith('accuracy'):
                  acc = hit_num / total_num
                  logger.info('Accuracy: {}'.format(acc))
                  eval_outputs.append(('accuracy', acc))
              elif metric == 'f1':
                  if model.config.num_labels == 2:
                      f1 = f1_score(y_trues, y_preds)
                      logger.info('F1: {}'.format(f1))
                      eval_outputs.append(('f1', f1))
                  else:
                      f1 = f1_score(y_trues, y_preds, average='macro')
                      logger.info('Macro F1: {}'.format(f1))
                      eval_outputs.append(('macro-f1', f1))
                      f1 = f1_score(y_trues, y_preds, average='micro')
                      logger.info('Micro F1: {}'.format(f1))
                      eval_outputs.append(('micro-f1', f1))
          # copy from easynlp (end)
          return dict(eval_outputs)