image_classifition_trainer.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. # Part of the implementation is borrowed and modified from mmclassification,
  2. # publicly available at https://github.com/open-mmlab/mmclassification
  3. import copy
  4. import os
  5. import os.path as osp
  6. import time
  7. from typing import Callable, Dict, Optional, Tuple, Union
  8. import numpy as np
  9. import torch
  10. from torch import nn
  11. from torch.utils.data import Dataset
  12. from modelscope.hub.snapshot_download import snapshot_download
  13. from modelscope.metainfo import Trainers
  14. from modelscope.models.base import TorchModel
  15. from modelscope.msdatasets.ms_dataset import MsDataset
  16. from modelscope.preprocessors.base import Preprocessor
  17. from modelscope.trainers.base import BaseTrainer
  18. from modelscope.trainers.builder import TRAINERS
  19. from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
  20. from modelscope.utils.logger import get_logger
  21. def train_model(model,
  22. dataset,
  23. cfg,
  24. distributed=False,
  25. val_dataset=None,
  26. timestamp=None,
  27. device=None,
  28. meta=None):
  29. import torch
  30. import warnings
  31. from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
  32. build_optimizer, build_runner, get_dist_info)
  33. from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
  34. from mmcls.datasets import build_dataloader
  35. from mmcls.utils import (wrap_distributed_model,
  36. wrap_non_distributed_model)
  37. from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
  38. logger = get_logger()
  39. # prepare data loaders
  40. dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
  41. sampler_cfg = cfg.train.get('sampler', None)
  42. data_loaders = [
  43. build_dataloader(
  44. ds,
  45. cfg.train.dataloader.batch_size_per_gpu,
  46. cfg.train.dataloader.workers_per_gpu,
  47. # cfg.gpus will be ignored if distributed
  48. num_gpus=len(cfg.gpu_ids),
  49. dist=distributed,
  50. round_up=True,
  51. seed=cfg.seed,
  52. sampler_cfg=sampler_cfg) for ds in dataset
  53. ]
  54. # put model on gpus
  55. if distributed:
  56. find_unused_parameters = cfg.get('find_unused_parameters', False)
  57. # Sets the `find_unused_parameters` parameter in
  58. # torch.nn.parallel.DistributedDataParallel
  59. model = MMDistributedDataParallel(
  60. model.cuda(),
  61. device_ids=[torch.cuda.current_device()],
  62. broadcast_buffers=False,
  63. find_unused_parameters=find_unused_parameters)
  64. else:
  65. if device == 'cpu':
  66. logger.warning(
  67. 'The argument `device` is deprecated. To use cpu to train, '
  68. 'please refers to https://mmclassification.readthedocs.io/en'
  69. '/latest/getting_started.html#train-a-model')
  70. model = model.cpu()
  71. else:
  72. model = MMDataParallel(model, device_ids=cfg.gpu_ids)
  73. if not model.device_ids:
  74. from mmcv import __version__, digit_version
  75. assert digit_version(__version__) >= (1, 4, 4), \
  76. 'To train with CPU, please confirm your mmcv version ' \
  77. 'is not lower than v1.4.4'
  78. # build runner
  79. optimizer = build_optimizer(model, cfg.train.optimizer)
  80. if cfg.train.get('runner') is None:
  81. cfg.train.runner = {
  82. 'type': 'EpochBasedRunner',
  83. 'max_epochs': cfg.train.max_epochs
  84. }
  85. logger.warning(
  86. 'config is now expected to have a `runner` section, '
  87. 'please set `runner` in your config.', UserWarning)
  88. runner = build_runner(
  89. cfg.train.runner,
  90. default_args=dict(
  91. model=model,
  92. batch_processor=None,
  93. optimizer=optimizer,
  94. work_dir=cfg.work_dir,
  95. logger=logger,
  96. meta=meta))
  97. # an ugly walkaround to make the .log and .log.json filenames the same
  98. runner.timestamp = timestamp
  99. # fp16 setting
  100. fp16_cfg = cfg.get('fp16', None)
  101. if fp16_cfg is not None:
  102. optimizer_config = Fp16OptimizerHook(
  103. **cfg.train.optimizer_config, **fp16_cfg, distributed=distributed)
  104. elif distributed and 'type' not in cfg.train.optimizer_config:
  105. optimizer_config = DistOptimizerHook(**cfg.train.optimizer_config)
  106. else:
  107. optimizer_config = cfg.train.optimizer_config
  108. # register hooks
  109. runner.register_training_hooks(
  110. cfg.train.lr_config,
  111. optimizer_config,
  112. cfg.train.checkpoint_config,
  113. cfg.train.log_config,
  114. cfg.train.get('momentum_config', None),
  115. custom_hooks_config=cfg.train.get('custom_hooks', None))
  116. if distributed and cfg.train.runner['type'] == 'EpochBasedRunner':
  117. runner.register_hook(DistSamplerSeedHook())
  118. # register eval hooks
  119. if val_dataset is not None:
  120. val_dataloader = build_dataloader(
  121. val_dataset,
  122. samples_per_gpu=cfg.evaluation.dataloader.batch_size_per_gpu,
  123. workers_per_gpu=cfg.evaluation.dataloader.workers_per_gpu,
  124. dist=distributed,
  125. shuffle=False,
  126. round_up=True)
  127. eval_cfg = cfg.train.get('evaluation', {})
  128. eval_cfg['by_epoch'] = cfg.train.runner['type'] != 'IterBasedRunner'
  129. eval_hook = DistEvalHook if distributed else EvalHook
  130. # `EvalHook` needs to be executed after `IterTimerHook`.
  131. # Otherwise, it will cause a bug if use `IterBasedRunner`.
  132. # Refers to https://github.com/open-mmlab/mmcv/issues/1261
  133. runner.register_hook(
  134. eval_hook(val_dataloader, **eval_cfg), priority='LOW')
  135. if cfg.train.resume_from:
  136. runner.resume(cfg.train.resume_from, map_location='cpu')
  137. elif cfg.train.load_from:
  138. runner.load_checkpoint(cfg.train.load_from)
  139. cfg.train.workflow = [tuple(flow) for flow in cfg.train.workflow]
  140. runner.run(data_loaders, cfg.train.workflow)
  141. @TRAINERS.register_module(module_name=Trainers.image_classification)
  142. class ImageClassifitionTrainer(BaseTrainer):
  143. def __init__(
  144. self,
  145. model: Optional[Union[TorchModel, nn.Module, str]] = None,
  146. cfg_file: Optional[str] = None,
  147. arg_parse_fn: Optional[Callable] = None,
  148. data_collator: Optional[Union[Callable, Dict[str,
  149. Callable]]] = None,
  150. train_dataset: Optional[Union[MsDataset, Dataset]] = None,
  151. eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
  152. preprocessor: Optional[Union[Preprocessor,
  153. Dict[str, Preprocessor]]] = None,
  154. optimizers: Tuple[torch.optim.Optimizer,
  155. torch.optim.lr_scheduler._LRScheduler] = (None,
  156. None),
  157. model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
  158. seed: int = 0,
  159. cfg_modify_fn: Optional[Callable] = None,
  160. **kwargs):
  161. """ High-level finetune api for Image Classifition.
  162. Args:
  163. model: model id
  164. model_version: model version, default is None.
  165. cfg_modify_fn: An input fn which is used to modify the cfg read out of the file.
  166. """
  167. import torch
  168. import mmcv
  169. from modelscope.models.cv.image_classification.utils import get_ms_dataset_root, get_classes
  170. from mmcls.models import build_classifier
  171. from mmcv.runner import get_dist_info, init_dist
  172. from mmcls.apis import set_random_seed
  173. from mmcls.utils import collect_env
  174. from mmcv.utils import get_logger as mmcv_get_logger
  175. import modelscope.models.cv.image_classification.backbones
  176. self._seed = seed
  177. set_random_seed(self._seed)
  178. if isinstance(model, str):
  179. self.model_dir = self.get_or_download_model_dir(
  180. model, model_revision=model_revision)
  181. if cfg_file is None:
  182. cfg_file = os.path.join(self.model_dir,
  183. ModelFile.CONFIGURATION)
  184. else:
  185. assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
  186. self.model_dir = os.path.dirname(cfg_file)
  187. super().__init__(cfg_file, arg_parse_fn)
  188. cfg = self.cfg
  189. if 'work_dir' in kwargs:
  190. self.work_dir = kwargs['work_dir']
  191. else:
  192. self.work_dir = self.cfg.train.get('work_dir', './work_dir')
  193. mmcv.mkdir_or_exist(osp.abspath(self.work_dir))
  194. cfg.work_dir = self.work_dir
  195. # evaluate config seting
  196. self.eval_checkpoint_path = os.path.join(self.model_dir,
  197. ModelFile.TORCH_MODEL_FILE)
  198. # train config seting
  199. if 'resume_from' in kwargs:
  200. cfg.train.resume_from = kwargs['resume_from']
  201. else:
  202. cfg.train.resume_from = cfg.train.get('resume_from', None)
  203. if 'load_from' in kwargs:
  204. cfg.train.load_from = kwargs['load_from']
  205. else:
  206. if cfg.train.get('resume_from', None) is None:
  207. cfg.train.load_from = os.path.join(self.model_dir,
  208. ModelFile.TORCH_MODEL_FILE)
  209. if 'device' in kwargs:
  210. cfg.device = kwargs['device']
  211. else:
  212. cfg.device = cfg.get('device', 'cuda')
  213. if 'gpu_ids' in kwargs:
  214. cfg.gpu_ids = kwargs['gpu_ids'][0:1]
  215. else:
  216. cfg.gpu_ids = [0]
  217. if 'fp16' in kwargs:
  218. cfg.fp16 = None if kwargs['fp16'] is None else kwargs['fp16']
  219. else:
  220. cfg.fp16 = None
  221. # no_validate=True will not evaluate checkpoint during training
  222. cfg.no_validate = kwargs.get('no_validate', False)
  223. if cfg_modify_fn is not None:
  224. cfg = cfg_modify_fn(cfg)
  225. if 'max_epochs' not in kwargs:
  226. assert hasattr(
  227. self.cfg.train,
  228. 'max_epochs'), 'max_epochs is missing in configuration file'
  229. self.max_epochs = self.cfg.train.max_epochs
  230. else:
  231. self.max_epochs = kwargs['max_epochs']
  232. cfg.train.max_epochs = self.max_epochs
  233. if cfg.train.get('runner', None) is not None:
  234. cfg.train.runner.max_epochs = self.max_epochs
  235. if 'launcher' in kwargs:
  236. distributed = True
  237. dist_params = kwargs['dist_params'] \
  238. if 'dist_params' in kwargs else {'backend': 'nccl'}
  239. init_dist(kwargs['launcher'], **dist_params)
  240. # re-set gpu_ids with distributed training mode
  241. _, world_size = get_dist_info()
  242. cfg.gpu_ids = list(range(world_size))
  243. else:
  244. distributed = False
  245. # init the logger before other steps
  246. mmcv_get_logger('modelscope') # set name of mmcv logger
  247. timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
  248. log_file = osp.join(self.work_dir, f'{timestamp}.log')
  249. logger = get_logger(log_file=log_file)
  250. # init the meta dict to record some important information such as
  251. # environment info and seed, which will be logged
  252. meta = dict()
  253. # log env info
  254. env_info_dict = collect_env()
  255. env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
  256. dash_line = '-' * 60 + '\n'
  257. logger.info('Environment info:\n' + dash_line + env_info + '\n'
  258. + dash_line)
  259. meta['env_info'] = env_info
  260. meta['config'] = cfg.pretty_text
  261. # log some basic info
  262. logger.info(f'Distributed training: {distributed}')
  263. logger.info(f'Config:\n{cfg.pretty_text}')
  264. # set random seeds
  265. cfg.seed = self._seed
  266. _deterministic = kwargs.get('deterministic', False)
  267. logger.info(f'Set random seed to {cfg.seed}, '
  268. f'deterministic: {_deterministic}')
  269. set_random_seed(cfg.seed, deterministic=_deterministic)
  270. meta['seed'] = cfg.seed
  271. meta['exp_name'] = osp.basename(cfg_file)
  272. # dataset
  273. self.train_dataset = train_dataset
  274. self.eval_dataset = eval_dataset
  275. # set img_prefix for image data path in csv files.
  276. if cfg.dataset.get('data_prefix', None) is None:
  277. self.data_prefix = ''
  278. else:
  279. self.data_prefix = cfg.dataset.data_prefix
  280. # model
  281. model = build_classifier(self.cfg.model.mm_model)
  282. model.init_weights()
  283. self.cfg = cfg
  284. self.device = cfg.device
  285. self.cfg_file = cfg_file
  286. self.model = model
  287. self.distributed = distributed
  288. self.timestamp = timestamp
  289. self.meta = meta
  290. self.logger = logger
  291. def train(self, *args, **kwargs):
  292. from mmcls import __version__
  293. from modelscope.models.cv.image_classification.utils import get_ms_dataset_root, MmDataset, preprocess_transform
  294. from mmcls.utils import setup_multi_processes
  295. if self.train_dataset is None:
  296. raise ValueError(
  297. "Not found train dataset, please set the 'train_dataset' parameter!"
  298. )
  299. self.cfg.model.mm_model.pretrained = None
  300. # dump config
  301. self.cfg.dump(osp.join(self.work_dir, osp.basename(self.cfg_file)))
  302. # build the dataloader
  303. if self.cfg.dataset.classes is None:
  304. data_root = get_ms_dataset_root(self.train_dataset)
  305. classname_path = osp.join(data_root, 'classname.txt')
  306. classes = classname_path if osp.exists(classname_path) else None
  307. else:
  308. classes = self.cfg.dataset.classes
  309. datasets = [
  310. MmDataset(
  311. self.train_dataset,
  312. pipeline=self.cfg.preprocessor.train,
  313. classes=classes,
  314. data_prefix=self.data_prefix)
  315. ]
  316. if len(self.cfg.train.workflow) == 2:
  317. if self.eval_dataset is None:
  318. raise ValueError(
  319. "Not found evaluate dataset, please set the 'eval_dataset' parameter!"
  320. )
  321. val_data_pipeline = self.cfg.preprocessor.train
  322. val_dataset = MmDataset(
  323. self.eval_dataset,
  324. pipeline=val_data_pipeline,
  325. classes=classes,
  326. data_prefix=self.data_prefix)
  327. datasets.append(val_dataset)
  328. # save mmcls version, config file content and class names in
  329. # checkpoints as meta data
  330. self.meta.update(
  331. dict(
  332. mmcls_version=__version__,
  333. config=self.cfg.pretty_text,
  334. CLASSES=datasets[0].CLASSES))
  335. val_dataset = None
  336. if not self.cfg.no_validate:
  337. val_dataset = MmDataset(
  338. self.eval_dataset,
  339. pipeline=preprocess_transform(self.cfg.preprocessor.val),
  340. classes=classes,
  341. data_prefix=self.data_prefix)
  342. # add an attribute for visualization convenience
  343. train_model(
  344. self.model,
  345. datasets,
  346. self.cfg,
  347. distributed=self.distributed,
  348. val_dataset=val_dataset,
  349. timestamp=self.timestamp,
  350. device='cpu' if self.device == 'cpu' else 'cuda',
  351. meta=self.meta)
  352. def evaluate(self,
  353. checkpoint_path: str = None,
  354. *args,
  355. **kwargs) -> Dict[str, float]:
  356. import warnings
  357. import torch
  358. from modelscope.models.cv.image_classification.utils import (
  359. get_ms_dataset_root, MmDataset, preprocess_transform,
  360. get_trained_checkpoints_name)
  361. from mmcls.datasets import build_dataloader
  362. from mmcv.runner import get_dist_info, load_checkpoint, wrap_fp16_model
  363. from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
  364. from mmcls.apis import multi_gpu_test, single_gpu_test
  365. from mmcls.utils import setup_multi_processes
  366. if self.eval_dataset is None:
  367. raise ValueError(
  368. "Not found evaluate dataset, please set the 'eval_dataset' parameter!"
  369. )
  370. self.cfg.model.mm_model.pretrained = None
  371. # build the dataloader
  372. if self.cfg.dataset.classes is None:
  373. data_root = get_ms_dataset_root(self.eval_dataset)
  374. classname_path = osp.join(data_root, 'classname.txt')
  375. classes = classname_path if osp.exists(classname_path) else None
  376. else:
  377. classes = self.cfg.dataset.classes
  378. dataset = MmDataset(
  379. self.eval_dataset,
  380. pipeline=preprocess_transform(self.cfg.preprocessor.val),
  381. classes=classes,
  382. data_prefix=self.data_prefix)
  383. # the extra round_up data will be removed during gpu/cpu collect
  384. data_loader = build_dataloader(
  385. dataset,
  386. samples_per_gpu=self.cfg.evaluation.dataloader.batch_size_per_gpu,
  387. workers_per_gpu=self.cfg.evaluation.dataloader.workers_per_gpu,
  388. dist=self.distributed,
  389. shuffle=False,
  390. round_up=True)
  391. model = copy.deepcopy(self.model)
  392. fp16_cfg = self.cfg.get('fp16', None)
  393. if fp16_cfg is not None:
  394. wrap_fp16_model(model)
  395. if checkpoint_path is None:
  396. trained_checkpoints = get_trained_checkpoints_name(self.work_dir)
  397. if trained_checkpoints is not None:
  398. checkpoint = load_checkpoint(
  399. model,
  400. os.path.join(self.work_dir, trained_checkpoints),
  401. map_location='cpu')
  402. else:
  403. checkpoint = load_checkpoint(
  404. model, self.eval_checkpoint_path, map_location='cpu')
  405. else:
  406. checkpoint = load_checkpoint(
  407. model, checkpoint_path, map_location='cpu')
  408. if 'CLASSES' in checkpoint.get('meta', {}):
  409. CLASSES = checkpoint['meta']['CLASSES']
  410. else:
  411. from mmcls.datasets import ImageNet
  412. self.logger.warning(
  413. 'Class names are not saved in the checkpoint\'s '
  414. 'meta data, use imagenet by default.')
  415. CLASSES = ImageNet.CLASSES
  416. if not self.distributed:
  417. if self.device == 'cpu':
  418. model = model.cpu()
  419. else:
  420. model = MMDataParallel(model, device_ids=self.cfg.gpu_ids)
  421. if not model.device_ids:
  422. assert mmcv.digit_version(mmcv.__version__) >= (1, 4, 4), \
  423. 'To test with CPU, please confirm your mmcv version ' \
  424. 'is not lower than v1.4.4'
  425. model.CLASSES = CLASSES
  426. show_kwargs = {}
  427. outputs = single_gpu_test(model, data_loader, False, None,
  428. **show_kwargs)
  429. else:
  430. model = MMDistributedDataParallel(
  431. model.cuda(),
  432. device_ids=[torch.cuda.current_device()],
  433. broadcast_buffers=False)
  434. outputs = multi_gpu_test(model, data_loader, None, True)
  435. rank, _ = get_dist_info()
  436. if rank == 0:
  437. results = {}
  438. logger = get_logger()
  439. metric_options = self.cfg.evaluation.get('metric_options', {})
  440. if 'topk' in metric_options.keys():
  441. metric_options['topk'] = tuple(metric_options['topk'])
  442. # mmcls will set the default value of topk to (1, 5) which
  443. # will cause error when number of classes less then 5.
  444. # set topk as (1,) if len(CLASSES) < 5:
  445. elif len(CLASSES) < 5:
  446. metric_options['topk'] = (1, )
  447. if self.cfg.evaluation.metrics:
  448. eval_results = dataset.evaluate(
  449. results=outputs,
  450. metric=self.cfg.evaluation.metrics,
  451. metric_options=metric_options,
  452. logger=logger)
  453. results.update(eval_results)
  454. return results
  455. return None