sparsity_hook.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from modelscope.metainfo import Hooks
  4. from modelscope.trainers.hooks.builder import HOOKS
  5. from modelscope.trainers.hooks.hook import Hook
  6. from modelscope.trainers.hooks.priority import Priority
  7. from modelscope.utils.checkpoint import save_checkpoint
  8. from modelscope.utils.torch_utils import is_master
  9. @HOOKS.register_module(module_name=Hooks.SparsityHook)
  10. class SparsityHook(Hook):
  11. PRIORITY = Priority.HIGHEST
  12. def __init__(self, pruning_method, config={}, save_dir=None):
  13. self.pruning_method = pruning_method
  14. self.save_dir = save_dir
  15. self.compress_module = config.get('compress_module', [])
  16. self.weight_rank = config.get('weight_rank', 8)
  17. self.weight_beta = config.get('weight_beta', 1)
  18. self.mask_rank = config.get('mask_rank', 8)
  19. self.mask_alpha1 = config.get('mask_alpha1', 1)
  20. self.mask_alpha2 = config.get('mask_alpha2', 1)
  21. self.step = 0
  22. self.total_step = 0
  23. self.frequency = config.get('frequency', 1)
  24. self.initial_warmup = config.get('initial_warmup', 0.1)
  25. self.final_warmup = config.get('final_warmup', 0.3)
  26. self.initial_sparsity = config.get('initial_sparsity', 0.0)
  27. self.final_sparsity = config.get('final_sparsity', 0.0)
  28. def before_run(self, trainer):
  29. import torch
  30. from .utils import SparseLinear, convert_sparse_network
  31. if self.save_dir is None:
  32. self.save_dir = trainer.work_dir
  33. if len(self.compress_module) == 0:
  34. convert_sparse_network(
  35. trainer.model,
  36. pruning_method=self.pruning_method,
  37. weight_rank=self.weight_rank,
  38. weight_beta=self.weight_beta,
  39. mask_rank=self.mask_rank,
  40. mask_alpha1=self.mask_alpha1,
  41. mask_alpha2=self.mask_alpha2,
  42. logger=trainer.logger,
  43. )
  44. else:
  45. for cm in self.compress_module:
  46. for name, module in trainer.model.named_modules():
  47. if name != cm:
  48. continue
  49. convert_sparse_network(
  50. module,
  51. pruning_method=self.pruning_method,
  52. weight_rank=self.weight_rank,
  53. weight_beta=self.weight_beta,
  54. mask_rank=self.mask_rank,
  55. mask_alpha1=self.mask_alpha1,
  56. mask_alpha2=self.mask_alpha2,
  57. logger=trainer.logger,
  58. )
  59. for i in range(len(trainer.optimizer.param_groups)):
  60. new_train_params = []
  61. for param in trainer.optimizer.param_groups[i]['params']:
  62. is_find = False
  63. for name, module in trainer.model.named_modules():
  64. if isinstance(module, SparseLinear):
  65. if torch.equal(param.half(),
  66. module.weight.data.half()):
  67. is_find = True
  68. break
  69. if not is_find:
  70. new_train_params.append(param)
  71. trainer.optimizer.param_groups[i]['params'] = new_train_params
  72. new_params = []
  73. for name, module in trainer.model.named_modules():
  74. if isinstance(module, SparseLinear):
  75. new_params.extend(
  76. [p for p in module.parameters() if p.requires_grad])
  77. trainer.optimizer.add_param_group({'params': new_params})
  78. self.total_step = trainer.iters_per_epoch * trainer._max_epochs
  79. def before_train_iter(self, trainer):
  80. from .utils import schedule_sparsity_ratio, update_network_sparsity
  81. cur_sparsity = schedule_sparsity_ratio(
  82. self.step,
  83. self.total_step,
  84. self.frequency,
  85. self.initial_warmup,
  86. self.final_warmup,
  87. self.initial_sparsity,
  88. self.final_sparsity,
  89. )
  90. update_network_sparsity(trainer.model, cur_sparsity)
  91. if is_master():
  92. trainer.logger.info(
  93. f'Step[{self.step}/{self.total_step}] current sparsity ratio = {cur_sparsity}'
  94. )
  95. self.step += 1
  96. def after_run(self, trainer):
  97. from .utils import generate_sparse_model
  98. generate_sparse_model(trainer.model, logger=trainer.logger)
  99. self._save_checkpoint(trainer)
  100. def _save_checkpoint(self, trainer):
  101. if is_master():
  102. trainer.logger.info('Saving checkpoint at final compress')
  103. cur_save_name = os.path.join(self.save_dir, 'compress_model.pth')
  104. save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)