| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import time
- from dataclasses import dataclass
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
- import numpy as np
- import torch
- from torch import nn
- from torch.utils.data import DataLoader, Dataset
- from tqdm import tqdm
- from transformers import DataCollatorWithPadding
- from modelscope.metainfo import Trainers
- from modelscope.models.base import Model, TorchModel
- from modelscope.models.nlp import BertForTextRanking
- from modelscope.msdatasets.ms_dataset import MsDataset
- from modelscope.preprocessors.base import Preprocessor
- from modelscope.trainers.builder import TRAINERS
- from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
- from modelscope.utils.constant import DEFAULT_MODEL_REVISION
- from modelscope.utils.logger import get_logger
- logger = get_logger()
@dataclass
class SentenceEmbeddingCollator(DataCollatorWithPadding):
    """Collate query/passage features into two separately padded batches.

    Each incoming feature is a mapping with a ``'query'`` entry (the encoded
    query) and a ``'docs'`` entry (the encoded passages belonging to that
    query, stored as one list per encoding key). Queries and passages are
    regrouped per encoding key and padded independently, so the model receives
    ``{'query': ..., 'docs': ...}`` and does not need to know how the raw
    features were laid out.
    """
    # Annotated on purpose: ``@dataclass`` only registers annotated names as
    # fields. Without the annotation this value is a plain class attribute
    # that the generated ``__init__`` shadows with the inherited field default
    # (``None`` in transformers' DataCollatorWithPadding -- confirm against
    # the installed version), so the intended 128 would never take effect.
    max_length: Optional[int] = 128
    # Plain (un-annotated) class attribute kept as in the original; it does
    # not alter the inherited ``tokenizer`` dataclass field.
    tokenizer = None

    def __call__(self, features):
        """Build the padded query and passage batches for ``features``.

        Args:
            features: Sequence of mappings, each holding a ``'query'`` mapping
                and a ``'docs'`` mapping whose values are lists with one entry
                per passage.

        Returns:
            A dict with keys ``'query'`` and ``'docs'``, each the result of
            ``tokenizer._tokenizer.pad`` with ``padding='max_length'`` and
            ``return_tensors='pt'``.
        """
        queries = [feature['query'] for feature in features]
        passages = [feature['docs'] for feature in features]

        # One list entry per feature for every encoding key of the query.
        query_batch = {
            key: [query[key] for query in queries]
            for key in queries[0].keys()
        }
        q_collated = self.tokenizer._tokenizer.pad(
            query_batch,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt')

        # Flatten the per-feature passage lists into one list per key in a
        # single pass; ``sum(lists, [])`` re-copies the accumulator on every
        # addition and is quadratic in the number of passages.
        passage_batch = {
            key: [row for group in passages for row in group[key]]
            for key in passages[0].keys()
        }
        d_collated = self.tokenizer._tokenizer.pad(
            passage_batch,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt')

        return {'query': q_collated, 'docs': d_collated}
@TRAINERS.register_module(module_name=Trainers.nlp_sentence_embedding_trainer)
class SentenceEmbeddingTrainer(NlpEpochBasedTrainer):
    """Epoch-based NLP trainer registered under the sentence-embedding name.

    Compared with its parent, it falls back to ``SentenceEmbeddingCollator``
    when the caller does not supply a data collator.
    """

    def __init__(
            self,
            model: Optional[Union[TorchModel, nn.Module, str]] = None,
            cfg_file: Optional[str] = None,
            cfg_modify_fn: Optional[Callable] = None,
            arg_parse_fn: Optional[Callable] = None,
            data_collator: Optional[Callable] = None,
            train_dataset: Optional[Union[MsDataset, Dataset]] = None,
            eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
            preprocessor: Optional[Preprocessor] = None,
            optimizers: Tuple[torch.optim.Optimizer,
                              torch.optim.lr_scheduler._LRScheduler] = (None,
                                                                        None),
            model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
            **kwargs):
        """Forward every argument unchanged to the parent trainer."""
        parent_args = dict(
            model=model,
            cfg_file=cfg_file,
            cfg_modify_fn=cfg_modify_fn,
            arg_parse_fn=arg_parse_fn,
            data_collator=data_collator,
            preprocessor=preprocessor,
            optimizers=optimizers,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            model_revision=model_revision,
        )
        super().__init__(**parent_args, **kwargs)

    def get_data_collator(self, data_collator, **kwargs):
        """Resolve the collator used for training and evaluation.

        Args:
            data_collator: Collator supplied by the caller, or None to use
                the sentence-embedding default.

        Returns:
            Whatever the parent's ``get_data_collator`` returns for the
            resolved collator.
        """
        # A caller-supplied collator is handed to the parent untouched.
        if data_collator is not None:
            return super().get_data_collator(data_collator, **kwargs)

        # Otherwise build the default from the training preprocessor so that
        # tokenizer and padding length agree with how the data was encoded.
        train_preprocessor = self.train_preprocessor
        default_collator = SentenceEmbeddingCollator(
            tokenizer=train_preprocessor.nlp_tokenizer,
            max_length=train_preprocessor.max_length)
        return super().get_data_collator(default_collator, **kwargs)

    # NOTE(review): the name looks like a misspelling of ``evaluate``. It is
    # left unchanged because renaming it would alter which method callers
    # resolve -- confirm the intent before touching it.
    def evauate(self):
        """Return an empty result dict."""
        return {}
|