sentence_embedding_trainer.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import time
  3. from dataclasses import dataclass
  4. from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  5. import numpy as np
  6. import torch
  7. from torch import nn
  8. from torch.utils.data import DataLoader, Dataset
  9. from tqdm import tqdm
  10. from transformers import DataCollatorWithPadding
  11. from modelscope.metainfo import Trainers
  12. from modelscope.models.base import Model, TorchModel
  13. from modelscope.models.nlp import BertForTextRanking
  14. from modelscope.msdatasets.ms_dataset import MsDataset
  15. from modelscope.preprocessors.base import Preprocessor
  16. from modelscope.trainers.builder import TRAINERS
  17. from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
  18. from modelscope.utils.constant import DEFAULT_MODEL_REVISION
  19. from modelscope.utils.logger import get_logger
# Module-level logger from modelscope's logging utility.
logger = get_logger()
  21. @dataclass
  22. class SentenceEmbeddingCollator(DataCollatorWithPadding):
  23. """
  24. Wrapper that does conversion from List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg]
  25. and pass batch separately to the actual collator.
  26. Abstract out data detail for the model.
  27. """
  28. max_length = 128
  29. tokenizer = None
  30. def __call__(self, features):
  31. qq = [f['query'] for f in features]
  32. dd = [f['docs'] for f in features]
  33. keys = qq[0].keys()
  34. qq = {k: [ele[k] for ele in qq] for k in keys}
  35. q_collated = self.tokenizer._tokenizer.pad(
  36. qq,
  37. padding='max_length',
  38. max_length=self.max_length,
  39. return_tensors='pt')
  40. keys = dd[0].keys()
  41. dd = {k: sum([ele[k] for ele in dd], []) for k in keys}
  42. d_collated = self.tokenizer._tokenizer.pad(
  43. dd,
  44. padding='max_length',
  45. max_length=self.max_length,
  46. return_tensors='pt')
  47. return {'query': q_collated, 'docs': d_collated}
  48. @TRAINERS.register_module(module_name=Trainers.nlp_sentence_embedding_trainer)
  49. class SentenceEmbeddingTrainer(NlpEpochBasedTrainer):
  50. def __init__(
  51. self,
  52. model: Optional[Union[TorchModel, nn.Module, str]] = None,
  53. cfg_file: Optional[str] = None,
  54. cfg_modify_fn: Optional[Callable] = None,
  55. arg_parse_fn: Optional[Callable] = None,
  56. data_collator: Optional[Callable] = None,
  57. train_dataset: Optional[Union[MsDataset, Dataset]] = None,
  58. eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
  59. preprocessor: Optional[Preprocessor] = None,
  60. optimizers: Tuple[torch.optim.Optimizer,
  61. torch.optim.lr_scheduler._LRScheduler] = (None,
  62. None),
  63. model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
  64. **kwargs):
  65. super().__init__(
  66. model=model,
  67. cfg_file=cfg_file,
  68. cfg_modify_fn=cfg_modify_fn,
  69. arg_parse_fn=arg_parse_fn,
  70. data_collator=data_collator,
  71. preprocessor=preprocessor,
  72. optimizers=optimizers,
  73. train_dataset=train_dataset,
  74. eval_dataset=eval_dataset,
  75. model_revision=model_revision,
  76. **kwargs)
  77. def get_data_collator(self, data_collator, **kwargs):
  78. """Get the data collator for both training and evaluating.
  79. Args:
  80. data_collator: The input data_collator param.
  81. Returns:
  82. The train_data_collator and eval_data_collator, can be None.
  83. """
  84. if data_collator is None:
  85. data_collator = SentenceEmbeddingCollator(
  86. tokenizer=self.train_preprocessor.nlp_tokenizer,
  87. max_length=self.train_preprocessor.max_length)
  88. return super().get_data_collator(data_collator, **kwargs)
  89. def evauate(self):
  90. return {}