text_ranking_trainer.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import time
  3. from dataclasses import dataclass
  4. from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  5. import numpy as np
  6. import torch
  7. from torch import nn
  8. from torch.utils.data import DataLoader, Dataset
  9. from tqdm import tqdm
  10. from modelscope.metainfo import Trainers
  11. from modelscope.models.base import Model, TorchModel
  12. from modelscope.models.nlp import BertForTextRanking
  13. from modelscope.msdatasets.ms_dataset import MsDataset
  14. from modelscope.preprocessors.base import Preprocessor
  15. from modelscope.trainers.builder import TRAINERS
  16. from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
  17. from modelscope.utils.constant import DEFAULT_MODEL_REVISION
  18. from modelscope.utils.logger import get_logger
  19. logger = get_logger()
  20. @dataclass
  21. class GroupCollator():
  22. """
  23. Wrapper that does conversion from List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg]
  24. and pass batch separately to the actual collator.
  25. Abstract out data detail for the model.
  26. """
  27. def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
  28. if isinstance(features[0], list):
  29. features = sum(features, [])
  30. keys = features[0].keys()
  31. batch = {k: list() for k in keys}
  32. for ele in features:
  33. for k, v in ele.items():
  34. batch[k].append(v)
  35. batch = {k: torch.cat(v, dim=0) for k, v in batch.items()}
  36. return batch
  37. @TRAINERS.register_module(module_name=Trainers.nlp_text_ranking_trainer)
  38. class TextRankingTrainer(NlpEpochBasedTrainer):
  39. def __init__(
  40. self,
  41. model: Optional[Union[TorchModel, nn.Module, str]] = None,
  42. cfg_file: Optional[str] = None,
  43. cfg_modify_fn: Optional[Callable] = None,
  44. arg_parse_fn: Optional[Callable] = None,
  45. data_collator: Optional[Callable] = None,
  46. train_dataset: Optional[Union[MsDataset, Dataset]] = None,
  47. eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
  48. preprocessor: Optional[Preprocessor] = None,
  49. optimizers: Tuple[torch.optim.Optimizer,
  50. torch.optim.lr_scheduler._LRScheduler] = (None,
  51. None),
  52. model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
  53. **kwargs):
  54. if data_collator is None:
  55. data_collator = GroupCollator()
  56. super().__init__(
  57. model=model,
  58. cfg_file=cfg_file,
  59. cfg_modify_fn=cfg_modify_fn,
  60. arg_parse_fn=arg_parse_fn,
  61. data_collator=data_collator,
  62. preprocessor=preprocessor,
  63. optimizers=optimizers,
  64. train_dataset=train_dataset,
  65. eval_dataset=eval_dataset,
  66. model_revision=model_revision,
  67. **kwargs)
  68. def compute_mrr(self, result, k=10):
  69. mrr = 0
  70. for res in result.values():
  71. sorted_res = sorted(res, key=lambda x: x[0], reverse=True)
  72. ar = 0
  73. for index, ele in enumerate(sorted_res[:k]):
  74. if str(ele[1]) == '1':
  75. ar = 1.0 / (index + 1)
  76. break
  77. mrr += ar
  78. return mrr / len(result)
  79. def compute_ndcg(self, result, k=10):
  80. ndcg = 0
  81. from sklearn import ndcg_score
  82. for res in result.values():
  83. sorted_res = sorted(res, key=lambda x: [0], reverse=True)
  84. labels = np.array([[ele[1] for ele in sorted_res]])
  85. scores = np.array([[ele[0] for ele in sorted_res]])
  86. ndcg += float(ndcg_score(labels, scores, k=k))
  87. ndcg = ndcg / len(result)
  88. return ndcg
  89. def evaluate(self,
  90. checkpoint_path: Optional[str] = None,
  91. *args,
  92. **kwargs) -> Dict[str, float]:
  93. """evaluate a dataset
  94. evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path`
  95. does not exist, read from the config file.
  96. Args:
  97. checkpoint_path (Optional[str], optional): the model path. Defaults to None.
  98. Returns:
  99. Dict[str, float]: the results about the evaluation
  100. Example:
  101. {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
  102. """
  103. # get the raw online dataset
  104. self.eval_dataloader = self._build_dataloader_with_dataset(
  105. self.eval_dataset,
  106. **self.cfg.evaluation.get('dataloader', {}),
  107. collate_fn=self.eval_data_collator)
  108. # generate a standard dataloader
  109. # generate a model
  110. if checkpoint_path is not None:
  111. model = BertForTextRanking.from_pretrained(checkpoint_path)
  112. else:
  113. model = self.model
  114. # copy from easynlp (start)
  115. model.eval()
  116. total_samples = 0
  117. logits_list = list()
  118. label_list = list()
  119. qid_list = list()
  120. total_spent_time = 0.0
  121. device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
  122. model.to(device)
  123. for _step, batch in enumerate(tqdm(self.eval_dataloader)):
  124. try:
  125. batch = {
  126. key:
  127. val.to(device) if isinstance(val, torch.Tensor) else val
  128. for key, val in batch.items()
  129. }
  130. except RuntimeError:
  131. batch = {key: val for key, val in batch.items()}
  132. infer_start_time = time.time()
  133. with torch.no_grad():
  134. label_ids = batch.pop('labels').detach().cpu().numpy()
  135. qids = batch.pop('qid').detach().cpu().numpy()
  136. outputs = model(**batch)
  137. infer_end_time = time.time()
  138. total_spent_time += infer_end_time - infer_start_time
  139. total_samples += self.eval_dataloader.batch_size
  140. def sigmoid(logits):
  141. return np.exp(logits) / (1 + np.exp(logits))
  142. logits = outputs['logits'].squeeze(-1).detach().cpu().numpy()
  143. logits = sigmoid(logits).tolist()
  144. label_list.extend(label_ids)
  145. logits_list.extend(logits)
  146. qid_list.extend(qids)
  147. logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format(
  148. total_spent_time, total_spent_time * 1000 / total_samples))
  149. rank_result = {}
  150. for qid, score, label in zip(qid_list, logits_list, label_list):
  151. if qid not in rank_result:
  152. rank_result[qid] = []
  153. rank_result[qid].append((score, label))
  154. for qid in rank_result:
  155. rank_result[qid] = sorted(rank_result[qid], key=lambda x: x[0])
  156. eval_outputs = list()
  157. for metric in self.metrics:
  158. if metric.startswith('mrr'):
  159. k = metric.split('@')[-1]
  160. k = int(k)
  161. mrr = self.compute_mrr(rank_result, k=k)
  162. logger.info('{}: {}'.format(metric, mrr))
  163. eval_outputs.append((metric, mrr))
  164. elif metric.startswith('ndcg'):
  165. k = metric.split('@')[-1]
  166. k = int(k)
  167. ndcg = self.compute_ndcg(rank_result, k=k)
  168. logger.info('{}: {}'.format(metric, ndcg))
  169. eval_outputs.append(('ndcg', ndcg))
  170. else:
  171. raise NotImplementedError('Metric %s not implemented' % metric)
  172. return dict(eval_outputs)