action_detection_trainer.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import os.path as osp
  4. from typing import Callable, Dict, Optional
  5. import torch
  6. from detectron2.checkpoint import DetectionCheckpointer
  7. from detectron2.data import (build_detection_test_loader,
  8. build_detection_train_loader)
  9. from detectron2.engine import SimpleTrainer, hooks, launch
  10. from detectron2.engine.defaults import create_ddp_model, default_writers
  11. from detectron2.evaluation import inference_on_dataset, print_csv_format
  12. from detectron2.solver import LRMultiplier, WarmupParamScheduler
  13. from detectron2.solver.build import get_default_optimizer_params
  14. from detectron2.utils import comm
  15. from detectron2.utils.file_io import PathManager
  16. from detectron2.utils.logger import setup_logger
  17. from fvcore.common.param_scheduler import CosineParamScheduler
  18. from modelscope.hub.check_model import check_local_model_is_latest
  19. from modelscope.hub.snapshot_download import snapshot_download
  20. from modelscope.metainfo import Trainers
  21. from modelscope.metrics.action_detection_evaluator import DetEvaluator
  22. from modelscope.models.cv.action_detection.modules.action_detection_pytorch import \
  23. build_action_detection_model
  24. from modelscope.preprocessors.cv.action_detection_mapper import VideoDetMapper
  25. from modelscope.trainers.base import BaseTrainer
  26. from modelscope.trainers.builder import TRAINERS
  27. from modelscope.utils.constant import Invoke, ModelFile, Tasks
  28. @TRAINERS.register_module(module_name=Trainers.action_detection)
  29. class ActionDetectionTrainer(BaseTrainer):
  30. def __init__(self,
  31. model_id,
  32. train_dataset,
  33. test_dataset,
  34. cfg_file: str = None,
  35. cfg_modify_fn: Optional[Callable] = None,
  36. *args,
  37. **kwargs):
  38. model_cache_dir = self.get_or_download_model_dir(model_id)
  39. if cfg_file is None:
  40. cfg_file = os.path.join(model_cache_dir, ModelFile.CONFIGURATION)
  41. super().__init__(cfg_file)
  42. if cfg_modify_fn is not None:
  43. self.cfg = cfg_modify_fn(self.cfg)
  44. self.total_step = self.cfg.train.max_iter
  45. self.warmup_step = self.cfg.train.lr_scheduler['warmup_step']
  46. self.lr = self.cfg.train.optimizer.lr
  47. self.total_batch_size = max(
  48. 1, self.cfg.train.num_gpus
  49. ) * self.cfg.train.dataloader['batch_size_per_gpu']
  50. self.num_classes = len(self.cfg.train.classes_id_map)
  51. self.resume = kwargs.get('resume', False)
  52. self.train_dataset = train_dataset
  53. self.test_dataset = test_dataset
  54. self.pretrained_model = kwargs.get(
  55. 'pretrained_model',
  56. osp.join(model_cache_dir, ModelFile.TORCH_MODEL_FILE))
  57. def start(self, output_dir):
  58. if comm.is_main_process() and output_dir:
  59. PathManager.mkdirs(output_dir)
  60. self.cfg.dump(osp.join(output_dir, 'config.py'))
  61. rank = comm.get_rank()
  62. setup_logger(output_dir, distributed_rank=rank, name='fvcore')
  63. logger = setup_logger(output_dir, distributed_rank=rank)
  64. logger.info('Rank of current process: {}. World size: {}'.format(
  65. rank, comm.get_world_size()))
  66. def train(self, *args, **kwargs):
  67. if self.cfg.train.num_gpus <= 1:
  68. self.do_train()
  69. else:
  70. launch(
  71. self.do_train,
  72. self.cfg.train.num_gpus,
  73. 1,
  74. machine_rank=0,
  75. dist_url='auto',
  76. args=())
  77. def evaluate(self, checkpoint_path: str, *args,
  78. **kwargs) -> Dict[str, float]:
  79. if self.cfg.train.num_gpus <= 1:
  80. self.do_train(just_eval=True, checkpoint_path=checkpoint_path)
  81. else:
  82. launch(
  83. self.do_train,
  84. self.cfg.train.num_gpus,
  85. 1,
  86. machine_rank=0,
  87. dist_url='auto',
  88. args=(True, checkpoint_path))
  89. def do_train(
  90. self,
  91. just_eval=False,
  92. checkpoint_path=None,
  93. ):
  94. self.start(self.cfg.train.work_dir)
  95. model = build_action_detection_model(num_classes=self.num_classes)
  96. if self.cfg.train.num_gpus > 0:
  97. model.cuda()
  98. model = create_ddp_model(model, broadcast_buffers=False)
  99. if just_eval:
  100. checkpoint = DetectionCheckpointer(model)
  101. checkpoint.load(checkpoint_path)
  102. result = self.do_test(model)
  103. return result
  104. optim = torch.optim.AdamW(
  105. params=get_default_optimizer_params(model, base_lr=self.lr),
  106. lr=self.lr,
  107. weight_decay=0.1)
  108. lr_scheduler = LRMultiplier(
  109. optim,
  110. WarmupParamScheduler(
  111. CosineParamScheduler(1, 1e-3),
  112. warmup_factor=0,
  113. warmup_length=self.warmup_step / self.total_step),
  114. max_iter=self.total_step,
  115. )
  116. train_loader = build_detection_train_loader(
  117. self.train_dataset,
  118. mapper=VideoDetMapper(
  119. self.cfg.train.classes_id_map, is_train=True),
  120. total_batch_size=self.total_batch_size,
  121. num_workers=self.cfg.train.dataloader.workers_per_gpu)
  122. trainer = SimpleTrainer(model, train_loader, optim)
  123. checkpointer = DetectionCheckpointer(
  124. model, self.cfg.train.work_dir, trainer=trainer)
  125. trainer.register_hooks([
  126. hooks.IterationTimer(),
  127. hooks.LRScheduler(scheduler=lr_scheduler),
  128. hooks.PeriodicCheckpointer(
  129. checkpointer, period=self.cfg.train.checkpoint_interval)
  130. if comm.is_main_process() else None,
  131. hooks.EvalHook(
  132. eval_period=self.cfg.evaluation.interval,
  133. eval_function=lambda: self.do_test(model)),
  134. hooks.PeriodicWriter(
  135. default_writers(checkpointer.save_dir, self.total_step),
  136. period=20) if comm.is_main_process() else None,
  137. ])
  138. checkpointer.resume_or_load(self.pretrained_model, resume=False)
  139. if self.resume:
  140. checkpointer.resume_or_load(resume=self.resume)
  141. start_iter = trainer.iter + 1
  142. else:
  143. start_iter = 0
  144. trainer.train(start_iter, self.total_step)
  145. def do_test(self, model):
  146. evaluator = DetEvaluator(
  147. list(self.cfg.train.classes_id_map.keys()),
  148. self.cfg.train.work_dir,
  149. distributed=self.cfg.train.num_gpus > 1)
  150. test_loader = build_detection_test_loader(
  151. self.test_dataset,
  152. mapper=VideoDetMapper(
  153. self.cfg.train.classes_id_map, is_train=False),
  154. num_workers=self.cfg.evaluation.dataloader.workers_per_gpu)
  155. result = inference_on_dataset(model, test_loader, evaluator)
  156. print_csv_format(result)
  157. return result
  158. def get_or_download_model_dir(self, model, model_revision=None):
  159. if os.path.exists(model):
  160. model_cache_dir = model if os.path.isdir(
  161. model) else os.path.dirname(model)
  162. check_local_model_is_latest(
  163. model_cache_dir, user_agent={Invoke.KEY: Invoke.LOCAL_TRAINER})
  164. else:
  165. model_cache_dir = snapshot_download(
  166. model,
  167. revision=model_revision,
  168. user_agent={Invoke.KEY: Invoke.TRAINER})
  169. return model_cache_dir