trainer.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import inspect
  3. import os
  4. from collections.abc import Mapping
  5. from copy import deepcopy
  6. from distutils.version import LooseVersion
  7. from functools import partial
  8. from typing import Callable, Dict, List, Optional, Tuple, Union
  9. import json
  10. import torch
  11. from torch import distributed as dist
  12. from torch import nn
  13. from torch.utils.data import DataLoader, Dataset, Sampler
  14. from torch.utils.data.dataloader import default_collate
  15. from torch.utils.data.distributed import DistributedSampler
  16. from modelscope.hub.check_model import check_local_model_is_latest
  17. from modelscope.metainfo import Trainers
  18. from modelscope.metrics import build_metric, task_default_metrics
  19. from modelscope.metrics.prediction_saving_wrapper import \
  20. PredictionSavingWrapper
  21. from modelscope.models.base import Model, TorchModel
  22. from modelscope.msdatasets.dataset_cls.custom_datasets import \
  23. TorchCustomDataset
  24. from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
  25. build_custom_dataset
  26. from modelscope.msdatasets.ms_dataset import MsDataset
  27. from modelscope.outputs import ModelOutputBase
  28. from modelscope.preprocessors.base import Preprocessor
  29. from modelscope.trainers.hooks.builder import HOOKS
  30. from modelscope.trainers.hooks.priority import Priority, get_priority
  31. from modelscope.trainers.lrscheduler.builder import build_lr_scheduler
  32. from modelscope.trainers.optimizer.builder import build_optimizer
  33. from modelscope.utils.config import Config, ConfigDict, JSONIteratorEncoder
  34. from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
  35. ConfigKeys, DistributedParallelType,
  36. Invoke, ModeKeys, ModelFile, ThirdParty,
  37. TrainerStages)
  38. from modelscope.utils.data_utils import to_device
  39. from modelscope.utils.device import create_device
  40. from modelscope.utils.file_utils import func_receive_dict_inputs
  41. from modelscope.utils.import_utils import is_swift_available
  42. from modelscope.utils.logger import get_logger
  43. from modelscope.utils.registry import build_from_cfg
  44. from modelscope.utils.torch_utils import (compile_model, get_dist_info,
  45. get_local_rank, init_dist, is_dist,
  46. is_master, is_on_same_device,
  47. set_random_seed)
  48. from .base import BaseTrainer
  49. from .builder import TRAINERS
  50. from .default_config import merge_cfg, merge_hooks, update_cfg
  51. from .hooks.hook import Hook
  52. from .parallel.builder import build_parallel
  53. from .parallel.utils import is_parallel
  # Type alias for a swift tuner configuration: either a ``swift.SwiftConfig``
  # or a ``swift.PeftConfig``. The names are string forward references, so the
  # optional ``swift`` package does not have to be importable to define it.
  TunerConfig = Union[
      'swift.SwiftConfig',
      'swift.PeftConfig',
  ]
  55. @TRAINERS.register_module(module_name=Trainers.default)
  56. class EpochBasedTrainer(BaseTrainer):
  57. """Epoch based Trainer, a training helper for PyTorch.
  58. Args:
  59. cfg_file(str): The local config file.
  60. model (:obj:`torch.nn.Module` or :obj:`TorchModel` or `str`): The model to be run, or a valid model dir
  61. or a model id. If model is None, build_model method will be called.
  62. data_collator (`Callable`, *optional*):
  63. The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`.
  64. train_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*):
  65. The dataset to use for training.
  66. Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
  67. distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
  68. `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
  69. manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
  70. sets the seed of the RNGs used.
  71. eval_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*): The dataset to use for evaluation.
  72. preprocessor (:obj:`Preprocessor`, *optional*): The optional preprocessor.
  73. NOTE: If the preprocessor has been called before the dataset fed into this trainer by user's custom code,
  74. this parameter should be None, meanwhile remove the 'preprocessor' key from the cfg_file.
  75. Else the preprocessor will be instantiated from the cfg_file or assigned from this parameter and
  76. this preprocessing action will be executed every time the dataset's __getitem__ is called.
  77. optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`, *optional*): A tuple
  78. containing the optimizer and the scheduler to use.
  79. seed (int): The optional random seed for torch, cuda, numpy and random.
  80. max_epochs: (int, optional): Total training epochs.
  81. cfg_modify_fn: An input fn which is used to modify the cfg read out of the file.
  82. remove_unused_data: Automatically remove unused data keys in mini-batches.
  83. The remove action based on the `inspect` on the model's forward method, the removed columns will be
  84. moved to the mini-batch's attributes.
  85. compile (bool, optional): Compile the model with torch 2.0, default False
  86. compile_options (dict, optional): The compile options if compile=True,
  87. default None to use the default params of 'TorchModel.compile'.
  88. efficient_tuners (dict, optional): The tuners to use to train the model
  89. samplers: (:obj:`Sampler` or `Dict[Sampler]`, *optional*): samplers used in the train/eval DataLoader.
  90. Examples of cfg_modify_fn:
  91. >>> def cfg_modify_fn(cfg):
  92. >>> cfg.preprocessor.first_sequence= 'text1'
  93. >>> cfg.preprocessor.second_sequence='text2'
  94. >>> return cfg
  95. """
  def __init__(
          self,
          model: Optional[Union[TorchModel, nn.Module, str]] = None,
          cfg_file: Optional[str] = None,
          cfg_modify_fn: Optional[Callable] = None,
          arg_parse_fn: Optional[Callable] = None,
          data_collator: Optional[Union[Callable, Dict[str,
                                                       Callable]]] = None,
          train_dataset: Optional[Union[MsDataset, Dataset]] = None,
          eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
          preprocessor: Optional[Union[Preprocessor,
                                       Dict[str, Preprocessor]]] = None,
          optimizers: Tuple[torch.optim.Optimizer,
                            torch.optim.lr_scheduler._LRScheduler] = (None,
                                                                       None),
          model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
          seed: int = 42,
          callbacks: Optional[List[Hook]] = None,
          samplers: Optional[Union[Sampler, Dict[str, Sampler]]] = None,
          efficient_tuners: Optional[Union[Dict[str, TunerConfig],
                                           TunerConfig]] = None,
          **kwargs):
      """Set up the trainer: config, model, work dir, logger, datasets, collators and hooks.

      The constructor arguments are described in the class docstring. Besides
      those, these keys are read from ``kwargs``: ``cfg_options`` (merged into
      the config), ``compile`` / ``compile_options``, ``work_dir``,
      ``remove_unused_data``, ``max_epochs``, ``train_iters_per_epoch``,
      ``val_iters_per_epoch``, ``use_fp16``, ``launcher`` and ``device``.
      ``ThirdParty.KEY`` is popped and forwarded to the model download/check
      helpers. ``kwargs`` (after that pop) is also forwarded as-is to
      ``build_dataset`` for both the train and eval datasets.

      Raises:
          AssertionError: If ``model`` is not a string and ``cfg_file`` is
              None, or if ``max_epochs`` is neither passed nor configured in
              ``train.max_epochs``.
      """
      # Seed RNGs before anything below (e.g. build_model) consumes randomness.
      self._seed = seed
      set_random_seed(self._seed)
      self._metric_values = None
      self.optimizers = optimizers
      self._mode = ModeKeys.TRAIN
      self._hooks: List[Hook] = []
      self._epoch = 0
      self._iter = 0
      self._inner_iter = 0
      self._stop_training = False
      self._compile = kwargs.get('compile', False)
      self.train_dataloader = None
      self.eval_dataloader = None
      self.data_loader = None
      self._samplers = samplers

      if isinstance(model, str):
          # A string model is a model id or local model dir: resolve it and,
          # unless given, take the configuration file from that directory.
          self.model_dir = self.get_or_download_model_dir(
              model, model_revision, kwargs.pop(ThirdParty.KEY, None))
          if cfg_file is None:
              cfg_file = os.path.join(self.model_dir,
                                      ModelFile.CONFIGURATION)
          self.input_model_id = model
      else:
          assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
          self.model_dir = os.path.dirname(cfg_file)
          self.input_model_id = None
          if hasattr(model, 'model_dir'):
              check_local_model_is_latest(
                  model.model_dir,
                  user_agent={
                      Invoke.KEY: Invoke.LOCAL_TRAINER,
                      ThirdParty.KEY: kwargs.pop(ThirdParty.KEY, None)
                  })

      super().__init__(cfg_file, arg_parse_fn)
      # rebuild_config reads self.cfg_modify_fn, so it must be set first.
      self.cfg_modify_fn = cfg_modify_fn
      # Add default config values, then apply cfg_modify_fn, cfg_options and
      # update_cfg, in that order.
      merge_cfg(self.cfg)
      self.cfg = self.rebuild_config(self.cfg)
      if 'cfg_options' in kwargs:
          self.cfg.merge_from_dict(kwargs['cfg_options'])
      self.cfg = update_cfg(self.cfg)

      # Use a ready model instance as-is; otherwise build one from the config.
      if isinstance(model, (TorchModel, nn.Module)):
          self.model = model
      else:
          self.model = self.build_model()

      if self._compile:
          # Compile the model with torch 2.0
          compile_options = kwargs.get('compile_options')
          if compile_options is None:
              compile_options = {}
          self.model = compile_model(self.model, **compile_options)

      if kwargs.get('work_dir', None) is not None:
          # An explicit work_dir overrides the config and is propagated to
          # the checkpoint (period/best) save dirs and the logging out dir.
          self.work_dir = kwargs['work_dir']
          if 'train' not in self.cfg:
              self.cfg['train'] = ConfigDict()
          self.cfg['train']['work_dir'] = self.work_dir
          if 'checkpoint' in self.cfg['train']:
              if 'period' in self.cfg['train']['checkpoint']:
                  self.cfg['train']['checkpoint']['period'][
                      'save_dir'] = self.work_dir
              if 'best' in self.cfg['train']['checkpoint']:
                  self.cfg['train']['checkpoint']['best'][
                      'save_dir'] = self.work_dir
          if 'logging' in self.cfg['train']:
              self.cfg['train']['logging']['out_dir'] = self.work_dir
      else:
          self.work_dir = self.cfg.train.get('work_dir', './work_dir')

      self.train_preprocessor, self.eval_preprocessor = self.get_preprocessors(
          preprocessor)

      if not os.path.exists(self.work_dir):
          # TODO duplicate makedirs may cause errors in dlc envs.
          os.makedirs(self.work_dir, exist_ok=True)

      # Create the file logger inside work_dir. Note that this runs before
      # the after_init hooks below, which set up the parallel groups.
      log_file = os.path.join(self.work_dir, '{}.log'.format(self.timestamp))
      self.logger = get_logger(
          log_file=log_file, log_level=self.cfg.get('log_level', 'INFO'))

      # Get train datasets
      self.train_dataset = self.build_dataset(
          datasets=train_dataset,
          model_cfg=self.cfg,
          mode=ModeKeys.TRAIN,
          preprocessor=self.train_preprocessor,
          **kwargs)
      # Get evaluation datasets
      self.eval_dataset = self.build_dataset(
          datasets=eval_dataset,
          model_cfg=self.cfg,
          mode=ModeKeys.EVAL,
          preprocessor=self.eval_preprocessor,
          **kwargs)

      self.train_data_collator, self.eval_data_collator = self.get_data_collator(
          data_collator,
          remove_unused_data=kwargs.get('remove_unused_data', False))
      # Explicit kwargs take precedence over the configuration file.
      self._max_epochs = kwargs.get('max_epochs',
                                    self.cfg.safe_get('train.max_epochs'))
      assert self._max_epochs is not None, 'max_epochs should be provided by the init arguments or configured ' \
                                           'in the `train.max_epochs` key in the configuration file.'
      self._train_iters_per_epoch = kwargs.get(
          'train_iters_per_epoch',
          self.cfg.safe_get('train.train_iters_per_epoch'))
      self._eval_iters_per_epoch = kwargs.get(
          'val_iters_per_epoch',
          self.cfg.safe_get('evaluation.val_iters_per_epoch'))
      self.use_fp16 = kwargs.get('use_fp16', False)
      self.launcher = kwargs.get('launcher')
      self.device = kwargs.get('device')
      # Wraps self.model with swift tuners when efficient_tuners is given.
      self.tune_module(efficient_tuners)

      # The parallel_groups field will be initialized in the hooks' after_init stage.
      # Please check the DDPHook and MegatronHook for details.
      self.parallel_groups = {}

      if self.launcher is not None and not self.cfg.safe_get(
              'train.hooks.DDPHook'):
          # Backward compatibility: when a launcher is given but no DDPHook
          # is configured, append one so it can run in after_init.
          if 'hooks' not in self.cfg.train:
              self.cfg.train['hooks'] = []
          self.cfg.train['hooks'].append({
              'type': 'DDPHook',
              'launcher': self.launcher
          })

      hooks = merge_hooks(self.cfg)
      self.register_hook_from_cfg(hooks)
      # Add user callbacks to hooks; a single callable is accepted as well.
      if callable(callbacks):
          callbacks = [callbacks]
      for callback in callbacks or []:
          self.register_hook(callback)
      self.invoke_hook(TrainerStages.after_init)

      # True when a data-parallel group is registered and it spans more
      # than one process.
      self._dist = self.is_dp_group_available() and dist.get_world_size(
          self.dp_group) > 1

      # Must run after eval_dataset is set: get_metrics raises if metrics
      # are missing while an eval dataset exists.
      self.metrics = self.get_metrics()

      if not self.parallel_groups:
          # If not working in parallel scenario, put model to device as a default logic.
          device_name = self.device if self.device is not None else 'gpu'
          self.device = create_device(device_name)
          if self.device.type == 'cuda' and is_on_same_device(self.model):
              self.model.to(self.device)

      self.print_cfg()
  def tune_module(self, efficient_tuners):
      """Wrap ``self.model`` with swift tuners when any are requested.

      Args:
          efficient_tuners: A swift tuner config or a dict of them. ``None``
              leaves the model untouched.

      Raises:
          ValueError: If tuners are requested but swift is not installed.
      """
      if efficient_tuners is None:
          return
      if not is_swift_available():
          raise ValueError(
              'Please install swift by `pip install ms-swift` to use efficient_tuners.'
          )
      # Imported lazily because swift is an optional dependency.
      from swift import Swift
      self.model = Swift.prepare_model(self.model, efficient_tuners)
  def place_model(self):
      """Move the model onto ``self.device`` and wrap it for data parallelism.

      The model is moved only when the device is CUDA. It is wrapped with
      ``self.to_parallel`` when it is not already parallel and ``self._dist``
      is truthy.
      """
      on_cuda = self.device.type == 'cuda'
      if on_cuda:
          self.model.to(self.device)
      needs_wrapping = not is_parallel(self.model) and self._dist
      if needs_wrapping:
          self.model = self.to_parallel(self.model)
  def get_data_collator(self, data_collator, remove_unused_data=False):
      """Get the data collator for both training and evaluating.

      Args:
          data_collator: Either a single callable used for both train and
              eval, a mapping with optional ``ConfigKeys.train`` /
              ``ConfigKeys.val`` entries, or None (``default_collate`` is used).
          remove_unused_data: Wrap both collators in ``RemoveColumnsCollator``
              so that only keys named in ``self.model.forward``'s signature
              are kept.

      Returns:
          The train_data_collator and eval_data_collator. Either can be None
          when a mapping without the corresponding key is given and
          ``remove_unused_data`` is False.
      """
      train_data_collator, eval_data_collator = None, None
      if isinstance(data_collator, Mapping):
          if ConfigKeys.train in data_collator:
              assert callable(data_collator[ConfigKeys.train])
              train_data_collator = data_collator[ConfigKeys.train]
          if ConfigKeys.val in data_collator:
              assert callable(data_collator[ConfigKeys.val])
              eval_data_collator = data_collator[ConfigKeys.val]
      else:
          collate_fn = default_collate if data_collator is None else data_collator
          train_data_collator = collate_fn
          eval_data_collator = collate_fn

      if remove_unused_data:
          from modelscope.utils.data_collators import RemoveColumnsCollator

          # Argument names accepted by the model's forward method.
          model_inputs = list(
              inspect.signature(self.model.forward).parameters.keys())
          # A mapping may omit the train or val key, leaving None. Wrapping
          # None would give RemoveColumnsCollator nothing to delegate to
          # (presumably it calls the wrapped collator after filtering; verify),
          # so fall back to default_collate. That is what DataLoader uses
          # when collate_fn is None anyway.
          if train_data_collator is None:
              train_data_collator = default_collate
          if eval_data_collator is None:
              eval_data_collator = default_collate
          train_data_collator = RemoveColumnsCollator(
              train_data_collator, model_inputs)
          eval_data_collator = RemoveColumnsCollator(eval_data_collator,
                                                     model_inputs)
      return train_data_collator, eval_data_collator
  def init_dist(self, launcher=None):
      """Initialize distributed mode if requested and report whether it is multi-process.

      Args:
          launcher: The launcher info. When None, no initialization is done
              and the current distributed state is only queried.

      Returns:
          True if the world size is greater than 1, otherwise False.
      """
      if launcher is not None:
          # This is the module-level ``init_dist`` helper, not this method.
          init_dist(launcher)
      _rank, world_size = get_dist_info()
      return world_size > 1
  def get_device(self, device=None):
      """Create the device this process should run on.

      Args:
          device: The requested device name. Defaults to ``'gpu'`` when None.
              It is ignored in distributed mode, where each process uses
              ``cuda:<local_rank>``.

      Returns:
          The device object produced by ``create_device``.
      """
      if is_dist():
          device_name = f'cuda:{get_local_rank()}'
      else:
          device_name = 'gpu' if device is None else device
      return create_device(device_name)
  def get_preprocessors(self, preprocessor):
      """Resolve the train and eval preprocessors.

      Args:
          preprocessor: A single ``Preprocessor`` (used for both splits), a
              mapping with optional ``ConfigKeys.train`` / ``ConfigKeys.val``
              entries, or anything else (typically None). In the last case
              they are built from ``cfg.preprocessor`` when it is present.

      Returns:
          Tuple ``(train_preprocessor, eval_preprocessor)``; either may be None.
      """
      train_preprocessor = eval_preprocessor = None
      if isinstance(preprocessor, Preprocessor):
          # NOTE(review): both names refer to the same object here, so the
          # mode assignments below leave it in EVAL mode. Confirm the mode is
          # reset elsewhere before training samples are processed.
          train_preprocessor = eval_preprocessor = preprocessor
      elif isinstance(preprocessor, Mapping):
          if ConfigKeys.train in preprocessor:
              train_candidate = preprocessor[ConfigKeys.train]
              assert isinstance(train_candidate, Callable)
              train_preprocessor = train_candidate
          if ConfigKeys.val in preprocessor:
              eval_candidate = preprocessor[ConfigKeys.val]
              assert isinstance(eval_candidate, Callable)
              eval_preprocessor = eval_candidate
      elif hasattr(self.cfg, ConfigFields.preprocessor
                   ) and self.cfg.preprocessor is not None:
          train_preprocessor, eval_preprocessor = self.build_preprocessor()

      if train_preprocessor is not None:
          train_preprocessor.mode = ModeKeys.TRAIN
      if eval_preprocessor is not None:
          eval_preprocessor.mode = ModeKeys.EVAL
      return train_preprocessor, eval_preprocessor
  354. def rebuild_config(self, cfg: Config):
  355. """A method used to rebuild the config, any subclass can override this method.
  356. Returns: The rebuilt config
  357. """
  358. if hasattr(self, 'cfg_modify_fn') and self.cfg_modify_fn is not None:
  359. cfg = self.cfg_modify_fn(cfg)
  360. return cfg
  @property
  def dp_group(self):
      """The data-parallel process group stored in ``self.parallel_groups``.

      Raises ``KeyError`` if no data-parallel group has been registered.
      """
      groups = self.parallel_groups
      return groups[DistributedParallelType.DP]
  @property
  def tp_group(self):
      """The tensor-parallel process group stored in ``self.parallel_groups``.

      Raises ``KeyError`` if no tensor-parallel group has been registered.
      """
      groups = self.parallel_groups
      return groups[DistributedParallelType.TP]
  @property
  def pp_group(self):
      """The pipeline-parallel process group stored in ``self.parallel_groups``.

      Raises ``KeyError`` if no pipeline-parallel group has been registered.
      """
      groups = self.parallel_groups
      return groups[DistributedParallelType.PP]
  def is_dp_group_available(self):
      """Return True if a data-parallel group is registered in ``self.parallel_groups``."""
      registered = self.parallel_groups
      return DistributedParallelType.DP in registered
  def is_tp_group_available(self):
      """Return True if a tensor-parallel group is registered in ``self.parallel_groups``."""
      registered = self.parallel_groups
      return DistributedParallelType.TP in registered
  def is_pp_group_available(self):
      """Return True if a pipeline-parallel group is registered in ``self.parallel_groups``."""
      registered = self.parallel_groups
      return DistributedParallelType.PP in registered
  394. @property
  395. def mode(self):
  396. return self._mode
  397. @property
  398. def hooks(self) -> List[Hook]:
  399. """list[:obj:`Hook`]: A list of registered hooks."""
  400. return self._hooks
  401. @property
  402. def epoch(self) -> int:
  403. """int: Current epoch."""
  404. return self._epoch
  405. @property
  406. def iter(self) -> int:
  407. """int: Current iteration."""
  408. return self._iter
  409. @property
  410. def inner_iter(self) -> int:
  411. """int: Iteration in an epoch."""
  412. return self._inner_iter
  413. @property
  414. def max_epochs(self):
  415. """int: Maximum training epochs."""
  416. return self._max_epochs
  @property
  def max_iters(self):
      """int: Total training iterations, i.e. ``max_epochs * iters_per_epoch``.

      Depends on ``iters_per_epoch``, which is evaluated for the current mode.
      """
      return self._max_epochs * self.iters_per_epoch
  @property
  def iters_per_epoch(self):
      """int: Total iterations of one epoch for the current mode.

      In TRAIN mode this is ``train_iters_per_epoch`` when configured, else
      ``len(self.train_dataloader)``. In EVAL mode it is
      ``val_iters_per_epoch`` when configured, else ``len(self.eval_dataloader)``.
      Any other mode yields None.

      Raises:
          ValueError: If the length is needed but the dataloader has no
              usable ``__len__``.
      """

      def _get_data_len(data_loader):
          try:
              return len(data_loader)
          except Exception as e:
              self.logger.error(e)
              # Chain the original error so the real cause stays in the traceback.
              raise ValueError(
                  'Please implement ``__len__`` method for your dataset, '
                  'or add `train_iters_per_epoch` and `val_iters_per_epoch` '
                  'to your configuration file or kwargs') from e

      if self.mode == ModeKeys.TRAIN:
          if self._train_iters_per_epoch is not None:
              return self._train_iters_per_epoch
          return _get_data_len(self.train_dataloader)
      if self.mode == ModeKeys.EVAL:
          if self._eval_iters_per_epoch is not None:
              return self._eval_iters_per_epoch
          return _get_data_len(self.eval_dataloader)
      return None
  def build_dataset(self,
                    datasets: Union[Dataset, MsDataset, List[Dataset]],
                    model_cfg: Config,
                    mode: str,
                    preprocessor: Optional[Preprocessor] = None,
                    **kwargs):
      """Build input datasets by given model configuration and preprocessor.

      Resolution order:
        * falsy ``datasets`` (None, empty list, ...): load from the
          ``dataset.*`` config keys via ``build_dataset_from_cfg``;
        * a ``TorchCustomDataset``: returned unchanged;
        * an ``MsDataset``: converted to a custom dataset if it is not one
          yet, and its ``ds_instance`` is returned;
        * a list whose first item is an ``MsDataset``: every item is
          converted and they are combined into one ``TorchCustomDataset``;
        * anything else: built with ``build_custom_dataset`` from the model
          type plus the ``dataset.train`` / ``dataset.val`` config section.

      If any of the above raises, the error is printed (not re-raised) and a
      fallback is used. Lists/tuples, or any input when a preprocessor is
      given, are wrapped in a ``TorchCustomDataset``; otherwise ``datasets``
      is returned as-is.

      Args:
          datasets (Union[Dataset, MsDataset, List[Dataset]]): The input datasets.
          model_cfg (Config): The model configuration.
          mode (str): `train`, `eval` or `inference`. See modelscope.utils.constant.ModeKeys
          preprocessor (Preprocessor, Optional): The preprocessor for input data samples.
          **kwargs: Forwarded to ``to_custom_dataset`` and, in the list case,
              to ``TorchCustomDataset``.

      Returns:
          Preprocessed datasets, or None when no datasets are given and none
          are configured.
      """
      try:
          if not datasets:
              return EpochBasedTrainer.build_dataset_from_cfg(
                  model_cfg=model_cfg, mode=mode, preprocessor=preprocessor)
          if isinstance(datasets, TorchCustomDataset):
              return datasets
          elif isinstance(datasets, MsDataset):
              if not datasets.is_custom:
                  datasets.to_custom_dataset(
                      custom_cfg=model_cfg,
                      preprocessor=preprocessor,
                      mode=mode,
                      **kwargs)
              return datasets.ds_instance
          elif isinstance(datasets, List) and isinstance(
                  datasets[0], MsDataset):
              custom_datasets = []
              for dataset in datasets:
                  if not dataset.is_custom:
                      dataset.to_custom_dataset(
                          custom_cfg=model_cfg,
                          preprocessor=preprocessor,
                          mode=mode,
                          **kwargs)
                  custom_datasets.append(dataset.ds_instance)
              # Each item was already given the preprocessor during
              # conversion, so the combined wrapper gets none of its own.
              torch_custom_dataset = TorchCustomDataset(
                  datasets=custom_datasets,
                  mode=mode,
                  preprocessor=None,
                  **kwargs)
              torch_custom_dataset.trainer = self
              return torch_custom_dataset
          else:
              dataset_mode_key = 'train' if mode == ModeKeys.TRAIN else 'val'
              data_config = model_cfg.safe_get(f'dataset.{dataset_mode_key}')
              if data_config is None:
                  # adapt to some special models
                  data_config = {}
              # Build a separate ConfigDict so the dataset and preprocessor
              # objects are not written into model_cfg itself.
              data_build_config = ConfigDict(
                  type=model_cfg.model.type,
                  mode=mode,
                  datasets=datasets,
                  preprocessor=preprocessor)
              data_build_config.update(data_config)
              custom_dataset = build_custom_dataset(data_build_config,
                                                    model_cfg.task)
              custom_dataset.trainer = self
              return custom_dataset
      except Exception as e:
          # Best-effort fallback: the error is only printed (not logged or
          # re-raised) before wrapping or returning the raw input.
          print('** build_dataset error log:', e)
          if isinstance(datasets, (List, Tuple)) or preprocessor is not None:
              custom_dataset = TorchCustomDataset(
                  datasets,
                  mode=mode,
                  preprocessor=preprocessor,
                  **(dict(type=model_cfg.model.type) if hasattr(
                      model_cfg, 'model') else {}))
              custom_dataset.trainer = self
              return custom_dataset
          else:
              return datasets
  def to_task_dataset(self, dataset: Dataset, mode: str,
                      preprocessor: Preprocessor,
                      **kwargs) -> TorchCustomDataset:
      r"""Wrap ``dataset`` in a :class:`TorchCustomDataset` bound to this trainer.

      @deprecated
      Kept for compatibility with subclasses that override it; new code
      should call `build_dataset()` instead. A deprecation warning is logged
      on every call.
      """
      self.logger.warning(
          'This to_task_dataset method is deprecated, please use build_dataset instead.'
      )
      wrapped = TorchCustomDataset(
          dataset, mode=mode, preprocessor=preprocessor, **kwargs)
      wrapped.trainer = self
      return wrapped
  535. @staticmethod
  536. def build_dataset_from_cfg(model_cfg: Config,
  537. mode: str,
  538. preprocessor: Preprocessor = None):
  539. dataset = None
  540. dataset_name = model_cfg.safe_get('dataset.name')
  541. subset_name = model_cfg.safe_get('dataset.subset', default='default')
  542. split_name = model_cfg.safe_get(f'dataset.split_{mode}')
  543. if not dataset_name or not split_name:
  544. return dataset
  545. dataset = MsDataset.load(
  546. dataset_name=dataset_name,
  547. subset_name=subset_name,
  548. split=split_name,
  549. custom_cfg=model_cfg)
  550. if not dataset.is_custom:
  551. dataset.to_custom_dataset(
  552. custom_cfg=model_cfg, preprocessor=preprocessor, mode=mode)
  553. return dataset.ds_instance
  def build_preprocessor(self) -> Tuple[Preprocessor, Preprocessor]:
      """Build the train and eval preprocessors from ``self.model_dir`` and ``self.cfg``.

      Subclasses can override this to supply custom preprocessing logic.

      Returns: A ``(train_preprocessor, eval_preprocessor)`` tuple.
      """

      def _load(preprocessor_mode):
          return Preprocessor.from_pretrained(
              self.model_dir,
              cfg_dict=self.cfg,
              preprocessor_mode=preprocessor_mode)

      train_preprocessor = _load(ModeKeys.TRAIN)
      eval_preprocessor = _load(ModeKeys.EVAL)
      return train_preprocessor, eval_preprocessor
  def get_metrics(self) -> List[Union[str, Dict]]:
      """Resolve the metric types used for evaluation.

      ``cfg.evaluation.metrics`` is preferred. When it is absent or None, the
      default metrics registered for ``cfg.task`` are used. A single string
      or mapping is wrapped in a one-element list.

      Raises:
          ValueError: If no metrics are found while an eval dataset exists.

      Returns: The metric types (may be None when there is no eval dataset).
      """
      metrics = None
      if hasattr(self.cfg, 'evaluation') and hasattr(self.cfg.evaluation,
                                                     'metrics'):
          metrics = self.cfg.evaluation.metrics
      if metrics is None:
          metrics = task_default_metrics.get(self.cfg.task)
      if metrics is None and self.eval_dataset is not None:
          raise ValueError(
              f'Metrics are needed in evaluation, please try to either '
              f'add metrics in configuration.json or add the default metric for {self.cfg.task}.'
          )
      return [metrics] if isinstance(metrics, (str, Mapping)) else metrics
  def set_checkpoint_file_to_hook(self, checkpoint_path, load_all_state,
                                  strict):
      """Point the first registered ``LoadCheckpointHook`` at ``checkpoint_path``.

      Does nothing when ``checkpoint_path`` is None. If no such hook is
      registered, a new one is created and registered first.
      """
      if checkpoint_path is None:
          return
      from modelscope.trainers.hooks import LoadCheckpointHook
      target_hook = next(
          (hook for hook in self.hooks
           if isinstance(hook, LoadCheckpointHook)), None)
      if target_hook is None:
          target_hook = LoadCheckpointHook()
          self.register_hook(target_hook)
      target_hook.checkpoint_file = checkpoint_path
      target_hook.load_all_state = load_all_state
      target_hook.strict = strict
  def train(self,
            checkpoint_path=None,
            load_all_state=True,
            *args,
            **kwargs):
      """Start training.

      Args:
          checkpoint_path(`str`, `optional`): A checkpoint previously saved by
              this trainer (usually a `some-file-name.pth` file) to resume from.
          load_all_state(`bool`: `optional`): If True, restore everything in
              `checkpoint_path`: the model, optimizer and lr_scheduler state
              dicts, the random state and the epoch/iter counters. If False,
              only the model's state dict is loaded and training starts over.
          kwargs:
              strict(`boolean`): If True, any unmatched keys raise an error.
                  Defaults to False.
      """
      self._mode = ModeKeys.TRAIN
      self.train_dataloader = self.get_train_dataloader()
      self.data_loader = self.train_dataloader
      self.register_optimizers_hook()
      self.register_processors()
      self.print_hook_info()
      strict = kwargs.get('strict', False)
      self.set_checkpoint_file_to_hook(checkpoint_path, load_all_state,
                                       strict)
      self.model.train()
      self.train_loop(self.train_dataloader)
  def predict(self,
              predict_datasets: Union[Dataset, List[Dataset]],
              saving_fn,
              checkpoint_path=None,
              strict=False):
      """Run prediction and hand the results to `saving_fn`.

      The datasets are built in evaluation mode with the evaluation
      preprocessor. No metrics are computed: `saving_fn` is wrapped in a
      `PredictionSavingWrapper`, which is the only "metric" passed to
      `evaluation_loop()`.

      Args:
          predict_datasets (Union[Dataset, List[Dataset]]): The dataset(s) to
              predict on.
          saving_fn (`Callable`): Persists prediction values, for example:
              >>> class SavingFn:
              >>>     def __init__(self):
              >>>         self.filename = '/tmp/results.txt'
              >>>
              >>>     def __call__(self, inputs, outputs):
              >>>         import numpy as np
              >>>         ids = inputs.ids
              >>>         predictions = np.argmax(outputs['logits'].cpu().numpy(), axis=1)
              >>>         with open(self.filename, 'a') as f:
              >>>             for id, pred in zip(ids, predictions):
              >>>                 f.writelines(f'{id}, {pred}')
              Results are not merged into one file. When running with
              multiple processes, combine the per-process files manually.
          checkpoint_path (`str`, `optional`): A checkpoint to load first:
              usually a `some-file-name.pth` file written by this trainer, or
              a plain PyTorch `some-file.bin` file.
          strict (`bool`): If True, unmatched keys in the checkpoint raise an
              error.
      """
      self.register_processors()
      self.print_hook_info()

      if checkpoint_path is not None:
          from modelscope.trainers.hooks import LoadCheckpointHook
          LoadCheckpointHook.load_checkpoint(
              checkpoint_path, self, strict=strict)

      self.model.eval()
      self._mode = ModeKeys.EVAL
      predict_dataloader = self.get_predict_dataloader(predict_datasets)

      saver = PredictionSavingWrapper(saving_fn=saving_fn)
      saver.trainer = self
      self.evaluation_loop(predict_dataloader, [saver])
  def evaluate(self, checkpoint_path=None, saving_fn=None, **kwargs):
      """Run evaluation on the eval dataset with the configured metrics.

      Args:
          checkpoint_path (`str`, `optional`): A checkpoint to load first:
              usually a `some-file-name.pth` file written by this trainer, or
              a plain PyTorch `some-file.bin` file.
          saving_fn (`Callable`, `optional`): If given, predictions are also
              passed to it through a `PredictionSavingWrapper`, for example:
              >>> class SavingFn:
              >>>     def __init__(self):
              >>>         self.filename = '/tmp/results.txt'
              >>>
              >>>     def __call__(self, inputs, outputs):
              >>>         import numpy as np
              >>>         ids = inputs.ids
              >>>         predictions = np.argmax(outputs['logits'].cpu().numpy(), axis=1)
              >>>         with open(self.filename, 'a') as f:
              >>>             for id, pred in zip(ids, predictions):
              >>>                 f.writelines(f'{id}, {pred}')
          kwargs:
              strict (`bool`): If True, unmatched keys in the checkpoint raise
                  an error. Defaults to False.

      Returns:
          The value returned by `evaluation_loop()`. It is also stored on
          `self._metric_values` (exposed as `metric_values`).
      """
      self.register_processors()
      self.print_hook_info()

      if checkpoint_path is not None:
          from modelscope.trainers.hooks import LoadCheckpointHook
          LoadCheckpointHook.load_checkpoint(
              checkpoint_path, self, strict=kwargs.get('strict', False))

      self.model.eval()
      self._mode = ModeKeys.EVAL
      self.eval_dataloader = self.get_eval_data_loader()
      self.data_loader = self.eval_dataloader

      evaluators = [build_metric(metric_cfg) for metric_cfg in self.metrics]
      if saving_fn is not None:
          evaluators.append(PredictionSavingWrapper(saving_fn=saving_fn))
      for evaluator in evaluators:
          evaluator.trainer = self

      self._metric_values = self.evaluation_loop(self.eval_dataloader,
                                                 evaluators)
      return self._metric_values
  @property
  def metric_values(self):
      """Metric values stored by the last `evaluate()` call.

      Reads `self._metric_values`, so it raises AttributeError if nothing
      has set that attribute yet.
      """
      return self._metric_values
  def build_model(self) -> Union[nn.Module, TorchModel]:
      """Instantiate the PyTorch model described by the configuration.

      `Model.from_pretrained` is called with `self.model_dir` and `self.cfg`.
      Override this in a subclass to build the model differently.

      Returns:
          The loaded object if it is an `nn.Module`. Otherwise its `.model`
          attribute if it has one. Otherwise None: neither case matched and
          the method falls through.
      """
      model = Model.from_pretrained(self.model_dir, cfg_dict=self.cfg)
      if isinstance(model, nn.Module):
          return model
      if hasattr(model, 'model'):
          return model.model
      return None
  def to_parallel(self, model) -> Union[nn.Module, TorchModel]:
      """Wrap `model` for data-parallel execution via `build_parallel`.

      If the config has a `parallel` section, a deep copy of it is used as
      the wrapper config (so a custom DDP variant can be chosen), with
      `module` and `device_ids` filled in. Otherwise a
      `DistributedDataParallel` wrapper is built with
      `find_unused_parameters=True` over `self.dp_group`. In both cases the
      only device id is the current CUDA device.
      """
      if self.cfg.get('parallel', None) is not None:
          wrapper_cfg = deepcopy(self.cfg['parallel'])
          wrapper_cfg.update(
              dict(module=model, device_ids=[torch.cuda.current_device()]))
      else:
          wrapper_cfg = dict(
              type='DistributedDataParallel',
              module=model,
              find_unused_parameters=True,
              device_ids=[torch.cuda.current_device()],
              process_group=self.dp_group)
      return build_parallel(wrapper_cfg)
  def unwrap_module(self, model) -> Union[nn.Module, TorchModel]:
      """Peel off wrapper layers until a bare `torch.nn.Module` remains.

      Any object exposing a `module` attribute (e.g. a DDP wrapper) is
      treated as a wrapper. The method recurses through `self.unwrap_module`,
      so subclass overrides also apply to inner layers.

      Args:
          model: A module, possibly wrapped one or more times.

      Raises:
          AssertionError: If the innermost object is not a `torch.nn.Module`.
      """
      if not hasattr(model, 'module'):
          assert isinstance(model, torch.nn.Module)
          return model
      return self.unwrap_module(model.module)
  def train_step(self, model, inputs):
      """Run one forward pass on a training batch and log its loss values.

      Subclass and override to inject custom behavior.

      Args:
          model (`TorchModel`): The (possibly wrapped) model to train.
          inputs (`Dict[str, Union[torch.Tensor, Any]]`): The batch. If it is
              a mapping and the unwrapped model's `forward` does not take a
              single dict argument (per `func_receive_dict_inputs`), it is
              unpacked as keyword arguments. Most models expect targets under
              `labels`; check the model's documentation.

      The forward output is converted to a dict (from `ModelOutputBase` if
      needed) and stored on `self.train_outputs`; nothing is returned. If the
      output has no 'log_vars' entry, every key containing 'loss' is logged.
      Under distributed training that value is averaged across processes
      first. Otherwise 'log_vars' is logged as-is.

      Raises:
          TypeError: If the forward output is not a dict.
      """
      # An evaluation hook may have left the model in eval mode; switch back.
      # TODO: find more pretty way to change mode
      model.train()
      self._mode = ModeKeys.TRAIN

      # Call `forward` directly (not `__call__`) to skip postprocessing.
      takes_dict = func_receive_dict_inputs(
          self.unwrap_module(self.model).forward)
      if not takes_dict and isinstance(inputs, Mapping):
          outputs = model.forward(**inputs)
      else:
          outputs = model.forward(inputs)

      if isinstance(outputs, ModelOutputBase):
          outputs = outputs.to_dict()
      if not isinstance(outputs, dict):
          raise TypeError('"model.forward()" must return a dict')

      if 'log_vars' in outputs:
          self.log_buffer.update(outputs['log_vars'])
      else:
          loss_keys = set()
          for pattern in ['loss']:
              loss_keys.update(
                  [name for name in outputs.keys() if pattern in name])

          collected = {}
          for name in loss_keys:
              value = outputs.get(name, None)
              if value is None:
                  continue
              if is_dist():
                  # Divide by world size, then all-reduce (sum) -> mean.
                  value = value.data.clone().to('cuda')
                  dist.all_reduce(value.div_(dist.get_world_size()))
              collected[name] = value.item()
          self.log_buffer.update(collected)

      self.train_outputs = outputs
  def prediction_step(self, model, inputs):
      """Deprecated; always raises.

      A deprecation warning is logged before raising. Use `predict()` or
      `evaluation_step()` instead.

      Raises:
          NotImplementedError: Always.
      """
      # `Logger.warn` is a deprecated alias; the rest of this trainer logs
      # through `warning`.
      self.logger.warning('This prediction_step method is deprecated.')
      raise NotImplementedError
  def get_train_dataloader(self):
      """Build the torch dataloader used for training.

      The defaults work for most cases. To change them, edit the
      `train.dataloader` section of the configuration file, or override this
      method in a subclass.

      The dataloader is built from `self.train_dataset` with
      `self.train_data_collator`. If samplers were supplied, the train entry
      is used (or the sampler itself when a single one was given).

      Raises:
          ValueError: If `self.train_dataset` is None.
      """
      if self.train_dataset is None:
          # Raising a bare string is a TypeError in Python 3; raise a real
          # exception carrying the intended message.
          raise ValueError('The train_dataset cannot be None.')

      sampler_cfg = {}
      if self._samplers is not None:
          sampler_cfg['sampler'] = self._samplers[
              ConfigKeys.train] if isinstance(self._samplers,
                                              dict) else self._samplers

      data_loader = self._build_dataloader_with_dataset(
          self.train_dataset,
          dist=self._dist,
          seed=self._seed,
          collate_fn=self.train_data_collator,
          **sampler_cfg,
          **self.cfg.train.get('dataloader', {}))
      return data_loader
  def get_eval_data_loader(self):
      """Build the torch dataloader used for evaluation.

      The defaults work for most cases. To change them, edit the
      `evaluation.dataloader` section of the configuration file, or override
      this method in a subclass. Shuffling is off unless that section turns
      it on.

      The dataloader is built from `self.eval_dataset` with
      `self.eval_data_collator`. If samplers were supplied, the val entry is
      used (or the sampler itself when a single one was given).

      Raises:
          ValueError: If `self.eval_dataset` is None.
      """
      if self.eval_dataset is None:
          # Raising a bare string is a TypeError in Python 3; raise a real
          # exception carrying the intended message.
          raise ValueError('The eval_dataset cannot be None.')

      sampler_cfg = {}
      if self._samplers is not None:
          sampler_cfg['sampler'] = self._samplers[
              ConfigKeys.val] if isinstance(self._samplers,
                                            dict) else self._samplers

      default_config = {'shuffle': False}
      default_config.update(self.cfg.evaluation.get('dataloader', {}))
      data_loader = self._build_dataloader_with_dataset(
          self.eval_dataset,
          dist=self._dist,
          seed=self._seed,
          collate_fn=self.eval_data_collator,
          **sampler_cfg,
          **default_config)
      return data_loader
  def get_predict_dataloader(self, predict_datasets: Union[Dataset,
                                                           List[Dataset]]):
      """Build a torch dataloader for prediction using the evaluation config.

      The datasets go through `self.build_dataset` in EVAL mode with the
      evaluation preprocessor. Loader options come from
      `evaluation.dataloader`, with shuffling off unless that section turns
      it on.

      Args:
          predict_datasets (Union[Dataset, List[Dataset]]): The dataset(s) to
              predict on.
      """
      dataset = self.build_dataset(
          datasets=predict_datasets,
          model_cfg=self.cfg,
          mode=ModeKeys.EVAL,
          preprocessor=self.eval_preprocessor)

      extra = {}
      if self._samplers is not None:
          if isinstance(self._samplers, dict):
              extra['sampler'] = self._samplers[ConfigKeys.val]
          else:
              extra['sampler'] = self._samplers

      loader_cfg = {
          'shuffle': False,
          **self.cfg.evaluation.get('dataloader', {})
      }
      return self._build_dataloader_with_dataset(
          dataset,
          dist=self._dist,
          seed=self._seed,
          collate_fn=self.eval_data_collator,
          **extra,
          **loader_cfg)
  def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
      """Build an optimizer for the unwrapped model from `cfg`.

      Delegates to the module-level `build_optimizer` registry helper.

      Raises:
          KeyError: Re-raised from the builder after logging a hint that the
              optimizer type may not exist in the installed torch version.
      """
      try:
          bare_model = self.unwrap_module(self.model)
          return build_optimizer(
              bare_model, cfg=cfg, default_args=default_args)
      except KeyError:
          self.logger.error(
              f'Build optimizer error, the optimizer {cfg} is a torch native component, '
              f'please check if your torch with version: {torch.__version__} matches the config.'
          )
          raise
  def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None):
      """Build an lr scheduler from `cfg`.

      Delegates to the module-level `build_lr_scheduler` registry helper.

      Raises:
          KeyError: Re-raised from the builder after logging a hint that the
              scheduler type may not exist in the installed torch version.
      """
      try:
          scheduler = build_lr_scheduler(cfg=cfg, default_args=default_args)
      except KeyError:
          self.logger.error(
              f'Build lr_scheduler error, the lr_scheduler {cfg} is a torch native component, '
              f'please check if your torch with version: {torch.__version__} matches the config.'
          )
          raise
      return scheduler
  def create_optimizer_and_scheduler(self):
      """Create the optimizer and lr scheduler, or reuse the provided ones.

      Whichever of `self.optimizers` is None is built from
      `train.optimizer` / `train.lr_scheduler` in the config, when that
      section exists. The section is deep-copied first, and its `options`
      entry is removed from the copy and returned separately. To customize,
      pass a tuple to the trainer's init function or override this method.

      Returns:
          tuple: `(optimizer, lr_scheduler, optim_options, lr_options)`. The
          two option dicts are empty unless they came from the config. The
          optimizer and scheduler are also stored on `self.optimizer` and
          `self.lr_scheduler`.
      """
      optimizer, lr_scheduler = self.optimizers

      optim_options = {}
      if optimizer is None:
          optimizer_cfg = deepcopy(self.cfg.train.get('optimizer', None))
          if optimizer_cfg is not None:
              optim_options = optimizer_cfg.pop('options', {})
              optimizer = self.build_optimizer(cfg=optimizer_cfg)

      lr_options = {}
      if lr_scheduler is None:
          scheduler_cfg = deepcopy(self.cfg.train.get('lr_scheduler', None))
          if scheduler_cfg is not None:
              assert optimizer is not None
              lr_options = scheduler_cfg.pop('options', {})
              lr_scheduler = self.build_lr_scheduler(
                  cfg=scheduler_cfg, default_args={'optimizer': optimizer})

      self.optimizer = optimizer
      self.lr_scheduler = lr_scheduler
      return self.optimizer, self.lr_scheduler, optim_options, lr_options
  def register_optimizers_hook(self):
      """Register the optimizer hook and the lr scheduler hook.

      First creates the optimizer and lr scheduler via
      `create_optimizer_and_scheduler()`. Hooks are then registered from
      `train.optimizer_hook` and `train.lr_scheduler_hook`:

      * A configured `lr_scheduler_hook` is registered as given. A default
        `LrSchedulerHook` is added only when no lr hook is configured, or
        when it is a `PlateauLrSchedulerHook`. Its options are the
        scheduler's `options` merged with the configured hook fields (minus
        `type`).
      * The same rule applies to `optimizer_hook`, with
        `TorchAMPOptimizerHook` / `ApexAMPOptimizerHook` as the exceptions.
        When one of them is configured, `self.use_fp16` is set to False so
        that AMP is not registered twice.
      * If `self.use_fp16` is still True afterwards, a
        `TorchAMPOptimizerHook` is registered with the merged optimizer
        options.

      The configuration is not modified, so calling this more than once
      registers the same hooks each time.

      Raises:
          ValueError: If the lr scheduler is a `ReduceLROnPlateau` and no
              `lr_scheduler_hook` is configured.
      """
      (_, lr_scheduler, optim_options,
       lr_options) = self.create_optimizer_and_scheduler()
      optim_hook = self.cfg.train.get('optimizer_hook', {})
      lr_hook = self.cfg.train.get('lr_scheduler_hook', {})

      # adapt to `ReduceLROnPlateau`
      from torch.optim.lr_scheduler import ReduceLROnPlateau
      if isinstance(lr_scheduler, ReduceLROnPlateau) and not lr_hook:
          plateau_cfg = {
              'train': {
                  'lr_scheduler_hook': {
                      'type': 'PlateauLrSchedulerHook',
                      'metric_key':
                      'Metric Key used for PlateauLrSchedulerHook'
                  }
              }
          }
          plateau_cfg = json.dumps(
              plateau_cfg, sort_keys=False, indent=4, separators=(',', ':'))
          raise ValueError(
              'Must add `lr_scheduler_hook` to configuration for `ReduceLROnPlateau` lr scheduler as follows:'
              + '\n' + plateau_cfg)

      amp_hook_types = ('TorchAMPOptimizerHook', 'ApexAMPOptimizerHook')

      def _without_type(hook_cfg):
          # New dict holding every field except `type`. `hook_cfg` comes
          # straight from `self.cfg`; popping `type` in place (as before)
          # corrupted the config for any later call, e.g. a second `train()`.
          return {k: v for k, v in hook_cfg.items() if k != 'type'}

      def _fit_to_old_keys():
          """Fit the `optimizer_hook` / `lr_scheduler_hook` keys of easycv configs.

          A provided optimizer_hook that is not TorchAMPOptimizerHook or
          ApexAMPOptimizerHook is treated as a complete optimization hook,
          so no OptimizerHook is added. Otherwise the OptimizerHook is
          registered. The LrSchedulerHook follows the same logic, with
          PlateauLrSchedulerHook as the exception. If an AMP optimizer hook
          is provided, self.use_fp16 is set to False to avoid registering
          it twice.

          Returns:
              tuple: `(optim_options, lr_options)`. Each is None when the
              corresponding default hook must not be registered.
          """
          if lr_hook:
              self.register_hook_from_cfg([lr_hook])
          _lr_options = None
          if not lr_hook or lr_hook.get('type') == 'PlateauLrSchedulerHook':
              _lr_options = {**lr_options, **_without_type(lr_hook)}

          if optim_hook:
              self.register_hook_from_cfg([optim_hook])
          _optim_options = None
          if optim_hook.get('type') in amp_hook_types:
              self.use_fp16 = False
          if not optim_hook or optim_hook.get('type') in amp_hook_types:
              _optim_options = {**optim_options, **_without_type(optim_hook)}
          return _optim_options, _lr_options

      optim_options, lr_options = _fit_to_old_keys()
      if optim_options is not None:
          self.register_hook_from_cfg(
              [dict(type='OptimizerHook', **optim_options)])
      if lr_options is not None:
          self.register_hook_from_cfg(
              [dict(type='LrSchedulerHook', **lr_options)])
      if self.use_fp16:
          # NOTE(review): `optim_options` is None when a custom, non-AMP
          # `optimizer_hook` is configured, and this `**` unpacking then
          # raises TypeError. Confirm the intended behavior for fp16 combined
          # with a complete custom optimizer hook.
          self.register_hook_from_cfg(
              [dict(type='TorchAMPOptimizerHook', **optim_options)])
  975. def _build_dataloader_with_dataset(self,
  976. dataset: Dataset,
  977. batch_size_per_gpu: int,
  978. workers_per_gpu: int,
  979. dist: bool = False,
  980. shuffle: bool = True,
  981. seed: int = 0,
  982. persistent_workers=False,
  983. **kwargs) -> DataLoader:
  984. """Build dataloader using input dataset and cfg. Used by `EpochBasedTrainer.train()`
  985. and `EpochBasedTrainer.evaluate()`.
  986. In distributed training, each GPU/process has a dataloader.
  987. In non-distributed training, there is only one dataloader for all GPUs.
  988. Args:
  989. dataset (Dataset): A PyTorch dataset.
  990. batch_size_per_gpu (int): Number of training samples on each GPU, i.e.,
  991. batch size of each GPU.
  992. workers_per_gpu (int): How many subprocesses to use for data loading
  993. for each GPU.
  994. dist (bool): Distributed training/test or not. Default: True.
  995. shuffle (bool): Whether to shuffle the data at every epoch.
  996. Default: True.
  997. seed (int, Optional): Seed to be used. Default: 0.
  998. runner_type (str): Type of runner. Default: `EpochBasedRunner`
  999. persistent_workers (bool): If True, the data loader will not shutdown
  1000. the worker processes after a dataset has been consumed once.
  1001. This allows to maintain the workers `Dataset` instances alive.
  1002. This argument is only valid when PyTorch>=1.7.0. Default: False.
  1003. kwargs: any keyword argument to be used to initialize DataLoader
  1004. Returns:
  1005. DataLoader: A PyTorch dataloader.
  1006. """
  1007. rank = 0
  1008. world_size = 1
  1009. if self.is_dp_group_available():
  1010. rank = torch.distributed.get_rank(self.dp_group)
  1011. world_size = torch.distributed.get_world_size(self.dp_group)
  1012. if dist:
  1013. # When model is :obj:`DistributedDataParallel`,
  1014. # `batch_size` of :obj:`dataloader` is the
  1015. # number of training samples on each GPU.
  1016. batch_size = batch_size_per_gpu
  1017. num_workers = workers_per_gpu
  1018. else:
  1019. batch_size = batch_size_per_gpu
  1020. num_workers = workers_per_gpu
  1021. sampler = kwargs.pop('sampler', None)
  1022. if sampler is None:
  1023. if dist and not isinstance(dataset,
  1024. torch.utils.data.IterableDataset):
  1025. sampler = DistributedSampler(
  1026. dataset,
  1027. num_replicas=world_size,
  1028. rank=rank,
  1029. shuffle=shuffle)
  1030. else:
  1031. sampler = None
  1032. if not isinstance(dataset, torch.utils.data.IterableDataset):
  1033. kwargs['shuffle'] = shuffle
  1034. batch_sampler = None
  1035. init_fn = partial(
  1036. worker_init_fn, num_workers=num_workers, rank=rank,
  1037. seed=seed) if seed is not None else None
  1038. if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
  1039. kwargs['persistent_workers'] = persistent_workers
  1040. elif persistent_workers is True:
  1041. self.logger.warning(
  1042. 'persistent_workers is invalid because your pytorch '
  1043. 'version is lower than 1.7.0')
  1044. data_loader = DataLoader(
  1045. dataset,
  1046. batch_size=batch_size,
  1047. sampler=sampler,
  1048. num_workers=num_workers,
  1049. batch_sampler=batch_sampler,
  1050. pin_memory=kwargs.pop('pin_memory', False),
  1051. worker_init_fn=init_fn,
  1052. **kwargs)
  1053. return data_loader
  def train_loop(self, data_loader):
      """Training loop used by `EpochBasedTrainer.train()`.

      Iterates epochs from `self._epoch` up to `self._max_epochs`, invoking
      the registered hooks around the whole run, around each epoch and
      around each iteration. Within an epoch, batches whose index is below
      `self.inner_iter` are skipped, so a resumed run continues mid-epoch.
      An epoch ends after `self.iters_per_epoch` batches. The loop stops
      early once `self._stop_training` is set when an epoch finishes.

      Args:
          data_loader: The iterable of training batches.
      """
      self.invoke_hook(TrainerStages.before_run)
      self.model.train()
      for _ in range(self._epoch, self._max_epochs):
          self.invoke_hook(TrainerStages.before_train_epoch)
          for i, data_batch in enumerate(data_loader):
              if i < self.inner_iter:
                  # inner_iter may be read out from the checkpoint file, so skip the trained iters in the epoch.
                  continue
              data_batch = to_device(data_batch, self.device)
              # Expose the current batch to hooks for the duration of this iteration.
              self.data_batch = data_batch
              self._inner_iter = i
              self.invoke_hook(TrainerStages.before_train_iter)
              self.train_step(self.model, data_batch)
              self.invoke_hook(TrainerStages.after_train_iter)
              # Value changed after the hooks are invoked, do not move them above the invoke_hook code.
              del self.data_batch
              self._iter += 1
              # Hooks (e.g. evaluation) may have switched the mode; restore it.
              self._mode = ModeKeys.TRAIN
              if i + 1 >= self.iters_per_epoch:
                  break
          self.invoke_hook(TrainerStages.after_train_epoch)
          # Value changed after the hooks are invoked, do not move them above the invoke_hook code.
          self._inner_iter = 0
          self._epoch += 1
          if self._stop_training:
              break
      self.invoke_hook(TrainerStages.after_run)
  def evaluation_step(self, data):
      """Run the model forward on one evaluation batch without gradients.

      Subclass and override to inject custom behavior.

      Args:
          data: The batch. If it is a mapping and the unwrapped model's
              `forward` does not take a single dict argument (per
              `func_receive_dict_inputs`), it is unpacked as keyword
              arguments.

      Returns:
          The raw output of `self.model.forward`.
      """
      self.model.eval()
      takes_dict = func_receive_dict_inputs(
          self.unwrap_module(self.model).forward)
      unpack = isinstance(data, Mapping) and not takes_dict
      with torch.no_grad():
          if unpack:
              return self.model.forward(**data)
          return self.model.forward(data)
  def evaluation_loop(self, data_loader, metric_classes):
      """Evaluation loop used by `evaluate()` and `predict()`.

      Runs `multi_gpu_test` when `self._dist` is set, otherwise
      `single_gpu_test`, between the `before_val` and `after_val` hooks.
      If the config has `evaluation.visualization`, `self.visualization` is
      bound with those options and passed along as the visualization
      closure.

      Args:
          data_loader: The batches to evaluate.
          metric_classes (list): Metric objects fed with the model outputs.

      Returns:
          Whatever the selected test helper returns.
      """
      vis_closure = None
      if hasattr(self.cfg.evaluation, 'visualization'):
          vis_closure = partial(
              self.visualization,
              dataset=self.eval_dataset,
              **self.cfg.evaluation.visualization)

      self.invoke_hook(TrainerStages.before_val)
      shared = dict(
          device=self.device,
          metric_classes=metric_classes,
          vis_closure=vis_closure)
      if not self._dist:
          from modelscope.trainers.utils.inference import single_gpu_test
          metric_values = single_gpu_test(
              self,
              data_loader,
              data_loader_iters=self._eval_iters_per_epoch,
              **shared)
      else:
          from modelscope.trainers.utils.inference import multi_gpu_test
          # list of batched result and data samples
          metric_values = multi_gpu_test(
              self,
              data_loader,
              tmpdir=self.cfg.evaluation.get('cache_dir', None),
              gpu_collect=self.cfg.evaluation.get('gpu_collect', False),
              data_loader_iters_per_gpu=self._eval_iters_per_epoch,
              **shared)
      self.invoke_hook(TrainerStages.after_val)
      return metric_values
  def visualization(self, batch_result, dataset, **kwargs):
      """Visualize evaluation results. Not implemented yet.

      Examples:
          >>> # draw list of images as numpy array
          >>> images = draw_images(num_of_visualization)
          >>> # set displayed name for each image
          >>> filenames = get_image_display_names()
          >>> vis_results = {'images': images, 'filenames' : filenames}
          >>> # visualization results will be displayed in group named eva_vis
          >>> self.visualization_buffer.output['eval_vis'] = vis_results

      Args:
          batch_result (list(dict)): A list of result dicts.
          dataset (Dataset): Torch dataset used to access the original data.
          kwargs: Extra visualization options.

      Raises:
          NotImplementedError: Always.
      """
      # TODO @wenmeng.zwm add visualization support for cv evaluation
      message = 'visualization for evaluation will be supported in the future'
      raise NotImplementedError(message)
  def register_hook(self, hook: Hook) -> None:
      """Insert `hook` into `self._hooks`, keeping the list ordered by priority.

      Hooks without a `PRIORITY` attribute count as `Priority.NORMAL`. The
      list is kept in ascending order of `get_priority(...)` value. Scanning
      from the end, the new hook goes right after the last hook whose value
      is strictly smaller than its own; if there is none, it goes first.

      NOTE(review): because the comparison is strict, a new hook is placed
      *before* existing hooks with the same priority value. Equal-priority
      hooks therefore run in reverse registration order. An earlier
      docstring claimed registration order; confirm which is intended.

      Args:
          hook (:obj:`Hook`): The hook to be registered.
      """
      own_priority = getattr(hook, 'PRIORITY', Priority.NORMAL)
      for index in range(len(self._hooks) - 1, -1, -1):
          other_priority = getattr(self._hooks[index], 'PRIORITY',
                                   Priority.NORMAL)
          if get_priority(own_priority) > get_priority(other_priority):
              self._hooks.insert(index + 1, hook)
              return
      self._hooks.insert(0, hook)
  def register_hook_from_cfg(self, hook_cfg: List) -> List:
      """Build hooks from their configs and register each one.

      Args:
          hook_cfg (list[dict]): Hook configs. Each is built with
              `build_from_cfg(cfg, HOOKS)`, which presumably needs at least
              a `type` key naming a registered hook class (and optionally
              `priority`).

      Note:
          The hook class being built should not take 'type' or 'priority'
          as its own initialization arguments.

      Returns:
          list: The registered hook instances, in the order of `hook_cfg`.

      Raises:
          AssertionError: If `hook_cfg` is not a list.
      """
      hook_cfg = hook_cfg.copy()
      assert isinstance(hook_cfg, list)
      registered = []
      for single_cfg in hook_cfg:
          new_hook = build_from_cfg(single_cfg, HOOKS)
          self.register_hook(new_hook)
          registered.append(new_hook)
      return registered
  def register_processors(self):
      """Call `register_processor(self)` on every hook that defines it.

      Hooks are visited in `self.hooks` order; the others are skipped.
      """
      with_processor = filter(lambda h: hasattr(h, 'register_processor'),
                              self.hooks)
      for hook in with_processor:
          hook.register_processor(self)
  def get_hook(self, cls):
      """Return the registered hooks whose class is exactly `cls`.

      Instances of subclasses of `cls` are not included.
      """
      return list(filter(lambda registered: registered.__class__ == cls,
                         self._hooks))
  def invoke_hook(self, fn_name: str) -> None:
      """Call `fn_name(self)` on every registered hook that has it.

      Hooks are visited in `self._hooks` order.

      Args:
          fn_name (str): The hook method to call, such as
              "before_train_epoch".
      """
      for hook in self._hooks:
          if not hasattr(hook, fn_name):
              continue
          callback = getattr(hook, fn_name)
          callback(self)
  def print_cfg(self):
      """Log the full configuration as indented JSON, on the master process only.

      A deep copy is logged with `train.work_dir` set to `self.work_dir`;
      `self.cfg` itself is left unchanged.
      """
      if not is_master():
          return
      shown = deepcopy(self.cfg)
      shown.train.work_dir = self.work_dir
      banner_start = '==========================Training Config Start=========================='
      banner_end = '===========================Training Config End==========================='
      self.logger.info(banner_start)
      self.logger.info(
          json.dumps(shown._cfg_dict, indent=4, cls=JSONIteratorEncoder))
      self.logger.info(banner_end)
  def print_hook_info(self):
      """Log the per-stage hook table once, and only on the master process."""
      if not is_master():
          return
      if getattr(self, '_hook_info_printed', False):
          return
      self.logger.info(self.get_hook_info())
      self._hook_info_printed = True
  def get_hook_info(self) -> str:
      """Describe which hooks are triggered at each trainer stage.

      Returns:
          str: One section per stage in `Hook.stages` order, omitting stages
          with no hooks. Each section lists the hooks returned by their
          `get_triggered_stages()` as `(PRIORITY_NAME) ClassName`.
      """
      stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
      for hook in self.hooks:
          try:
              priority = Priority(hook.PRIORITY).name  # type: ignore
          except Exception:
              # Best effort for missing or unrecognised priorities. Use `.name`
              # so this line matches the others; formatting the enum member
              # itself may render as e.g. `Priority.NORMAL` instead.
              priority = Priority.NORMAL.name
          classname = hook.__class__.__name__
          hook_info = f'({priority:<12}) {classname:<35}'
          if hasattr(hook, 'get_triggered_stages'):
              for trigger_stage in hook.get_triggered_stages():
                  stage_hook_map[trigger_stage].append(hook_info)

      stage_hook_infos = []
      for stage in Hook.stages:
          hook_infos = stage_hook_map[stage]
          if len(hook_infos) > 0:
              info = f'Stage: {stage}:\n '
              info += '\n '.join(hook_infos)
              info += '\n -------------------- '
              stage_hook_infos.append(info)
      stage_hook_infos = '\n'.join(stage_hook_infos)
      return stage_hook_infos
  def worker_init_fn(worker_id, num_workers, rank, seed):
      """Seed one dataloader worker deterministically.

      The seed is `num_workers * rank + worker_id + seed`. When `worker_id`
      ranges over 0..num_workers-1, each (rank, worker) pair gets a distinct
      seed for a given base `seed`.
      """
      offset = num_workers * rank + worker_id
      set_random_seed(offset + seed)