| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- import time
- from abc import ABC, abstractmethod
- from typing import Callable, Dict, Optional
- from modelscope.hub.check_model import check_local_model_is_latest
- from modelscope.hub.snapshot_download import snapshot_download
- from modelscope.trainers.builder import TRAINERS
- from modelscope.utils.config import Config
- from modelscope.utils.constant import Invoke, ThirdParty
- from .utils.log_buffer import LogBuffer
class BaseTrainer(ABC):
    """Abstract base for all trainers; cannot be instantiated directly.

    Declares the ``train`` / ``evaluate`` interface that concrete trainers
    must implement, and provides shared setup: loading the configuration
    file, optionally converting it into command-line style args, creating
    log buffers, and resolving model directories.
    """

    def __init__(self, cfg_file: str, arg_parse_fn: Optional[Callable] = None):
        """Load the configuration and initialize shared trainer state.

        Derived classes should call this from their own ``__init__``.

        Args:
            cfg_file: Path to the configuration file, loaded via
                :meth:`Config.from_file`.
            arg_parse_fn: Same as ``parse_fn`` in :obj:`Config.to_args`.
                When it is falsy, ``self.args`` is set to ``None``.
        """
        self.cfg = Config.from_file(cfg_file)
        self.args = self.cfg.to_args(arg_parse_fn) if arg_parse_fn else None
        # Two separate LogBuffer instances: one for general logging and one
        # for visualization data.
        self.log_buffer = LogBuffer()
        self.visualization_buffer = LogBuffer()
        # Captured once at construction time, in local time.
        self.timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())

    def get_or_download_model_dir(self,
                                  model,
                                  model_revision=None,
                                  third_party=None):
        """Resolve ``model`` to a local model directory.

        If ``model`` exists on disk, it is treated as a local path:
        - a directory is used as-is;
        - a file maps to its parent directory.
        ``check_local_model_is_latest`` is then called on that directory.

        Otherwise ``model`` is treated as a model id and fetched with
        ``snapshot_download``.

        Args:
            model (str): Model id, or path to a local model directory/file.
            model_revision (str, optional): Revision forwarded to
                ``snapshot_download``; ignored for local paths.
            third_party (str, optional): Name of the third-party library
                this function is called from. It is included in the
                ``user_agent`` passed to the hub helpers.

        Returns:
            The local model directory path.
        """
        if not os.path.exists(model):
            # Not on disk: treat as a hub model id and download it.
            return snapshot_download(
                model,
                revision=model_revision,
                user_agent={
                    Invoke.KEY: Invoke.TRAINER,
                    ThirdParty.KEY: third_party
                })

        local_dir = model if os.path.isdir(model) else os.path.dirname(model)
        check_local_model_is_latest(
            local_dir,
            user_agent={
                Invoke.KEY: Invoke.LOCAL_TRAINER,
                ThirdParty.KEY: third_party
            })
        return local_dir

    @abstractmethod
    def train(self, *args, **kwargs):
        """Run the training (and optionally evaluation) loop.

        Concrete trainers implement this for their task or model. They can
        use the state prepared in ``BaseTrainer.__init__`` (``self.cfg``,
        ``self.args``, the log buffers).
        """

    @abstractmethod
    def evaluate(self, checkpoint_path: str, *args,
                 **kwargs) -> Dict[str, float]:
        """Run evaluation and return a mapping of metric name to value.

        Concrete trainers implement this for their task or model. They can
        use the state prepared in ``BaseTrainer.__init__``.

        Args:
            checkpoint_path: Path of the checkpoint to evaluate.
        """
@TRAINERS.register_module(module_name='dummy')
class DummyTrainer(BaseTrainer):
    """Minimal trainer that only prints sections of its configuration.

    Registered in ``TRAINERS`` under the name ``'dummy'``. No model is
    trained or evaluated; presumably intended for exercising the trainer
    registry/builder (TODO confirm).
    """

    def __init__(self, cfg_file: str, *args, **kwargs):
        """Dummy Trainer.

        Args:
            cfg_file: Path to configuration file.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored. No ``arg_parse_fn`` is forwarded
                to the base class, so ``self.args`` is always ``None``.
        """
        super().__init__(cfg_file)

    def train(self, *args, **kwargs):
        """Print the ``train`` section of the config; no training happens.

        All arguments are ignored.
        """
        cfg = self.cfg.train
        print(f'train cfg {cfg}')

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """Print the ``evaluation`` config section and the checkpoint path.

        Args:
            checkpoint_path: Checkpoint path. It is only printed, never
                loaded.

        Returns:
            Always an empty dict. No metrics are computed, but the result
            still matches the declared ``Dict[str, float]`` contract, so
            callers can iterate it like any other trainer's metrics.
        """
        cfg = self.cfg.evaluation
        print(f'eval cfg {cfg}')
        print(f'checkpoint_path {checkpoint_path}')
        return {}
|