| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import time
- from dataclasses import dataclass
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
- import numpy as np
- import torch
- from torch import nn
- from torch.utils.data import DataLoader, Dataset
- from tqdm import tqdm
- from modelscope.metainfo import Trainers
- from modelscope.models.base import Model, TorchModel
- from modelscope.models.nlp import BertForTextRanking
- from modelscope.msdatasets.ms_dataset import MsDataset
- from modelscope.preprocessors.base import Preprocessor
- from modelscope.trainers.builder import TRAINERS
- from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
- from modelscope.utils.constant import DEFAULT_MODEL_REVISION
- from modelscope.utils.logger import get_logger
# Module-level logger; used by the trainer's evaluation code in this file.
logger = get_logger()
@dataclass
class GroupCollator():
    """Flatten (optionally grouped) feature dicts and concatenate them.

    ``features`` is either a list of feature dicts (name -> tensor) or, for
    grouped ranking data, a list of lists of such dicts. When the first
    element is a list, the groups are flattened into one list of dicts.
    Every tensor stored under the same key is then concatenated along dim 0.
    The keys of the first (flattened) feature dict define the batch keys.

    Raises:
        IndexError: if ``features`` is empty.
    """

    def __call__(
        self, features: List[Union[Dict[str, Any], List[Dict[str, Any]]]]
    ) -> Dict[str, Any]:
        if isinstance(features[0], list):
            # Linear-time flatten; ``sum(features, [])`` re-copies the
            # accumulator on every step and is quadratic in the batch size.
            features = [feature for group in features for feature in group]
        batch = {key: [] for key in features[0]}
        for feature in features:
            for key, value in feature.items():
                batch[key].append(value)
        return {key: torch.cat(values, dim=0) for key, values in batch.items()}
@TRAINERS.register_module(module_name=Trainers.nlp_text_ranking_trainer)
class TextRankingTrainer(NlpEpochBasedTrainer):
    """Trainer for text-ranking models with per-query ranking evaluation.

    Uses :class:`GroupCollator` as the default data collator and evaluates
    with ``mrr@k`` / ``ndcg@k`` metrics computed over the candidates of each
    query id.
    """

    def __init__(
            self,
            model: Optional[Union[TorchModel, nn.Module, str]] = None,
            cfg_file: Optional[str] = None,
            cfg_modify_fn: Optional[Callable] = None,
            arg_parse_fn: Optional[Callable] = None,
            data_collator: Optional[Callable] = None,
            train_dataset: Optional[Union[MsDataset, Dataset]] = None,
            eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
            preprocessor: Optional[Preprocessor] = None,
            optimizers: Tuple[torch.optim.Optimizer,
                              torch.optim.lr_scheduler._LRScheduler] = (None,
                                                                        None),
            model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
            **kwargs):
        """Build the trainer; arguments are forwarded to the base trainer.

        Args:
            data_collator: batch collator. A ``GroupCollator`` is used when
                None. All other arguments are passed unchanged to
                ``NlpEpochBasedTrainer.__init__``.
        """
        if data_collator is None:
            data_collator = GroupCollator()
        super().__init__(
            model=model,
            cfg_file=cfg_file,
            cfg_modify_fn=cfg_modify_fn,
            arg_parse_fn=arg_parse_fn,
            data_collator=data_collator,
            preprocessor=preprocessor,
            optimizers=optimizers,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            model_revision=model_revision,
            **kwargs)

    def compute_mrr(self, result, k=10):
        """Mean reciprocal rank at cutoff ``k``.

        Args:
            result: mapping of query id -> list of ``(score, label)`` pairs.
            k: only the ``k`` highest-scored candidates are inspected.

        Returns:
            For each query, the reciprocal (1-based) rank of the first
            candidate whose label stringifies to ``'1'`` within the top ``k``
            (0 if there is none), averaged over queries. Returns 0.0 for an
            empty ``result``.
        """
        if not result:
            return 0.0
        total = 0.0
        for candidates in result.values():
            ranked = sorted(candidates, key=lambda x: x[0], reverse=True)
            for rank, (_, label) in enumerate(ranked[:k], start=1):
                if str(label) == '1':
                    total += 1.0 / rank
                    break
        return total / len(result)

    def compute_ndcg(self, result, k=10):
        """Mean NDCG@k over queries, using scikit-learn's ``ndcg_score``.

        Args:
            result: mapping of query id -> list of ``(score, label)`` pairs.
            k: rank cutoff passed to ``ndcg_score``.

        Returns:
            The per-query NDCG@k averaged over queries; 0.0 for an empty
            ``result``.
        """
        if not result:
            return 0.0
        # ``ndcg_score`` lives in ``sklearn.metrics``, not the package root.
        from sklearn.metrics import ndcg_score
        total = 0.0
        for candidates in result.values():
            ranked = sorted(candidates, key=lambda x: x[0], reverse=True)
            labels = np.array([[label for _, label in ranked]])
            scores = np.array([[score for score, _ in ranked]])
            # NOTE(review): recent scikit-learn versions reject inputs with a
            # single candidate per query — confirm the eval data never has
            # one-candidate queries.
            total += float(ndcg_score(labels, scores, k=k))
        return total / len(result)

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """Evaluate on ``self.eval_dataset`` with per-query ranking metrics.

        Args:
            checkpoint_path (Optional[str], optional): if given, a
                ``BertForTextRanking`` loaded from this path is evaluated;
                otherwise ``self.model`` is used. Defaults to None.

        Returns:
            Dict[str, float]: one entry per name in ``self.metrics``, keyed
            by that name (e.g. ``"mrr@10"``, ``"ndcg@10"``). For backward
            compatibility the last computed ndcg value is also stored under
            the key ``"ndcg"``.

        Raises:
            NotImplementedError: if a metric name starts with neither
                ``mrr`` nor ``ndcg``.
            ValueError: if the text after the last ``@`` of a metric name is
                not an integer.
        """
        self.eval_dataloader = self._build_dataloader_with_dataset(
            self.eval_dataset,
            **self.cfg.evaluation.get('dataloader', {}),
            collate_fn=self.eval_data_collator)
        if checkpoint_path is not None:
            model = BertForTextRanking.from_pretrained(checkpoint_path)
        else:
            model = self.model
        rank_result = self._predict_rank_result(model)
        return self._compute_ranking_metrics(rank_result)

    def _predict_rank_result(self, model):
        """Score ``self.eval_dataloader`` with ``model``; group by query id.

        Each batch must contain ``labels`` and ``qid`` tensors, which are
        popped before the remaining entries are passed to ``model``.

        Returns:
            dict mapping query id -> list of ``(score, label)`` pairs, where
            ``score`` is the sigmoid of the model's ``logits`` output.
        """
        model.eval()
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        model.to(device)

        total_samples = 0
        total_spent_time = 0.0
        rank_result = {}
        for batch in tqdm(self.eval_dataloader):
            try:
                batch = {
                    key:
                    val.to(device) if isinstance(val, torch.Tensor) else val
                    for key, val in batch.items()
                }
            except RuntimeError as err:
                # Best effort: keep the batch where it is, but do not hide it.
                logger.warning(
                    'Failed to move batch to %s, keeping original device: %s',
                    device, err)
                batch = dict(batch)
            infer_start_time = time.time()
            with torch.no_grad():
                label_ids = batch.pop('labels').detach().cpu().numpy()
                qids = batch.pop('qid').detach().cpu().numpy()
                outputs = model(**batch)
            total_spent_time += time.time() - infer_start_time
            # Count the rows actually scored: the last batch may be short and
            # grouped batches are flattened, so ``batch_size`` is not exact.
            total_samples += len(label_ids)

            # torch.sigmoid is numerically stable; exp(x) / (1 + exp(x))
            # overflows to NaN for large logits. atleast_1d keeps a
            # single-row batch iterable after squeezing.
            scores = torch.sigmoid(
                torch.atleast_1d(outputs['logits'].squeeze(-1)))
            for qid, score, label in zip(qids,
                                         scores.detach().cpu().tolist(),
                                         label_ids):
                rank_result.setdefault(qid, []).append((score, label))

        per_sample_ms = (total_spent_time * 1000
                         / total_samples) if total_samples else 0.0
        logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format(
            total_spent_time, per_sample_ms))
        return rank_result

    def _compute_ranking_metrics(self, rank_result):
        """Compute every metric named in ``self.metrics`` on ``rank_result``.

        Metric names are expected to look like ``mrr@<k>`` / ``ndcg@<k>``.
        ``self.metrics`` is not set in this class — presumably populated by
        the base trainer from the config; verify.
        """
        eval_outputs = {}
        for metric in self.metrics:
            if metric.startswith('mrr'):
                compute = self.compute_mrr
            elif metric.startswith('ndcg'):
                compute = self.compute_ndcg
            else:
                raise NotImplementedError('Metric %s not implemented' % metric)
            k = int(metric.split('@')[-1])
            value = compute(rank_result, k=k)
            logger.info('{}: {}'.format(metric, value))
            # Key by the full metric name so several ``@k`` cutoffs coexist.
            eval_outputs[metric] = value
            if compute == self.compute_ndcg:
                # Legacy alias: earlier versions stored ndcg only as 'ndcg'.
                eval_outputs['ndcg'] = value
        return eval_outputs
|