# base.py
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import time
  4. from abc import ABC, abstractmethod
  5. from typing import Callable, Dict, Optional
  6. from modelscope.hub.check_model import check_local_model_is_latest
  7. from modelscope.hub.snapshot_download import snapshot_download
  8. from modelscope.trainers.builder import TRAINERS
  9. from modelscope.utils.config import Config
  10. from modelscope.utils.constant import Invoke, ThirdParty
  11. from .utils.log_buffer import LogBuffer
  12. class BaseTrainer(ABC):
  13. """ Base class for trainer which can not be instantiated.
  14. BaseTrainer defines necessary interface
  15. and provide default implementation for basic initialization
  16. such as parsing config file and parsing commandline args.
  17. """
  18. def __init__(self, cfg_file: str, arg_parse_fn: Optional[Callable] = None):
  19. """ Trainer basic init, should be called in derived class
  20. Args:
  21. cfg_file: Path to configuration file.
  22. arg_parse_fn: Same as ``parse_fn`` in :obj:`Config.to_args`.
  23. """
  24. self.cfg = Config.from_file(cfg_file)
  25. if arg_parse_fn:
  26. self.args = self.cfg.to_args(arg_parse_fn)
  27. else:
  28. self.args = None
  29. self.log_buffer = LogBuffer()
  30. self.visualization_buffer = LogBuffer()
  31. self.timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
  32. def get_or_download_model_dir(self,
  33. model,
  34. model_revision=None,
  35. third_party=None):
  36. """ Get local model directory or download model if necessary.
  37. Args:
  38. model (str): model id or path to local model directory.
  39. model_revision (str, optional): model version number.
  40. third_party (str, optional): in which third party library
  41. this function is called.
  42. """
  43. if os.path.exists(model):
  44. model_cache_dir = model if os.path.isdir(
  45. model) else os.path.dirname(model)
  46. check_local_model_is_latest(
  47. model_cache_dir,
  48. user_agent={
  49. Invoke.KEY: Invoke.LOCAL_TRAINER,
  50. ThirdParty.KEY: third_party
  51. })
  52. else:
  53. model_cache_dir = snapshot_download(
  54. model,
  55. revision=model_revision,
  56. user_agent={
  57. Invoke.KEY: Invoke.TRAINER,
  58. ThirdParty.KEY: third_party
  59. })
  60. return model_cache_dir
  61. @abstractmethod
  62. def train(self, *args, **kwargs):
  63. """ Train (and evaluate) process
  64. Train process should be implemented for specific task or
  65. model, related parameters have been initialized in
  66. ``BaseTrainer.__init__`` and should be used in this function
  67. """
  68. pass
  69. @abstractmethod
  70. def evaluate(self, checkpoint_path: str, *args,
  71. **kwargs) -> Dict[str, float]:
  72. """ Evaluation process
  73. Evaluation process should be implemented for specific task or
  74. model, related parameters have been initialized in
  75. ``BaseTrainer.__init__`` and should be used in this function
  76. """
  77. pass
  78. @TRAINERS.register_module(module_name='dummy')
  79. class DummyTrainer(BaseTrainer):
  80. def __init__(self, cfg_file: str, *args, **kwargs):
  81. """ Dummy Trainer.
  82. Args:
  83. cfg_file: Path to configuration file.
  84. """
  85. super().__init__(cfg_file)
  86. def train(self, *args, **kwargs):
  87. """ Train (and evaluate) process
  88. Train process should be implemented for specific task or
  89. model, related parameters have been initialized in
  90. ``BaseTrainer.__init__`` and should be used in this function
  91. """
  92. cfg = self.cfg.train
  93. print(f'train cfg {cfg}')
  94. def evaluate(self,
  95. checkpoint_path: str = None,
  96. *args,
  97. **kwargs) -> Dict[str, float]:
  98. """ Evaluation process
  99. Evaluation process should be implemented for specific task or
  100. model, related parameters have been initialized in
  101. ``BaseTrainer.__init__`` and should be used in this function
  102. """
  103. cfg = self.cfg.evaluation
  104. print(f'eval cfg {cfg}')
  105. print(f'checkpoint_path {checkpoint_path}')