| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- from modelscope.metainfo import Hooks
- from modelscope.trainers.hooks.builder import HOOKS
- from modelscope.trainers.hooks.hook import Hook
- from modelscope.trainers.hooks.priority import Priority
- from modelscope.utils.checkpoint import save_checkpoint
- from modelscope.utils.torch_utils import is_master
@HOOKS.register_module(module_name=Hooks.SparsityHook)
class SparsityHook(Hook):
    """Trainer hook that drives sparsity-based compression of a model.

    Lifecycle, as implemented below:

    * ``before_run``: converts the model (or only the sub-modules named in
      ``compress_module``) with ``convert_sparse_network`` and rebuilds the
      optimizer parameter groups around the resulting ``SparseLinear``
      modules.
    * ``before_train_iter``: computes the sparsity ratio for the current
      step with ``schedule_sparsity_ratio`` and applies it with
      ``update_network_sparsity``.
    * ``after_run``: calls ``generate_sparse_model`` and saves a checkpoint
      named ``compress_model.pth``.

    Args:
        pruning_method: Forwarded unchanged to ``convert_sparse_network``.
        config: Optional mapping of settings. Recognised keys and their
            defaults: ``compress_module`` (``[]``), ``weight_rank`` (8),
            ``weight_beta`` (1), ``mask_rank`` (8), ``mask_alpha1`` (1),
            ``mask_alpha2`` (1), ``frequency`` (1), ``initial_warmup``
            (0.1), ``final_warmup`` (0.3), ``initial_sparsity`` (0.0),
            ``final_sparsity`` (0.0). ``None`` means "all defaults".
        save_dir: Directory for the final checkpoint. When ``None`` it is
            replaced by ``trainer.work_dir`` in ``before_run``.
    """

    PRIORITY = Priority.HIGHEST

    def __init__(self, pruning_method, config=None, save_dir=None):
        # Use None rather than a shared mutable ``{}`` default; this also
        # makes an explicit ``config=None`` behave like "no config".
        if config is None:
            config = {}

        self.pruning_method = pruning_method
        self.save_dir = save_dir

        # Names (as reported by ``named_modules``) of the sub-modules to
        # convert; an empty list means the whole model is converted.
        self.compress_module = config.get('compress_module', [])

        # Settings forwarded to ``convert_sparse_network``.
        self.weight_rank = config.get('weight_rank', 8)
        self.weight_beta = config.get('weight_beta', 1)
        self.mask_rank = config.get('mask_rank', 8)
        self.mask_alpha1 = config.get('mask_alpha1', 1)
        self.mask_alpha2 = config.get('mask_alpha2', 1)

        # Step counters: ``step`` advances once per training iteration,
        # ``total_step`` is filled in by ``before_run``.
        self.step = 0
        self.total_step = 0

        # Settings forwarded to ``schedule_sparsity_ratio``.
        self.frequency = config.get('frequency', 1)
        self.initial_warmup = config.get('initial_warmup', 0.1)
        self.final_warmup = config.get('final_warmup', 0.3)
        self.initial_sparsity = config.get('initial_sparsity', 0.0)
        self.final_sparsity = config.get('final_sparsity', 0.0)

    def before_run(self, trainer):
        """Convert the model and rebuild the optimizer parameter groups.

        Args:
            trainer: The running trainer. This method reads ``work_dir``,
                ``model``, ``optimizer``, ``logger``, ``iters_per_epoch``
                and ``_max_epochs`` from it.
        """
        import torch

        from .utils import SparseLinear, convert_sparse_network

        if self.save_dir is None:
            self.save_dir = trainer.work_dir

        convert_kwargs = dict(
            pruning_method=self.pruning_method,
            weight_rank=self.weight_rank,
            weight_beta=self.weight_beta,
            mask_rank=self.mask_rank,
            mask_alpha1=self.mask_alpha1,
            mask_alpha2=self.mask_alpha2,
            logger=trainer.logger,
        )

        if not self.compress_module:
            convert_sparse_network(trainer.model, **convert_kwargs)
        else:
            for cm in self.compress_module:
                # The module tree is re-walked for every name because an
                # earlier conversion may have changed it.
                for name, module in trainer.model.named_modules():
                    if name == cm:
                        convert_sparse_network(module, **convert_kwargs)
                        # A name is yielded at most once, so stop here
                        # instead of walking the rest of the tree.
                        break
                else:
                    # Surface misspelled names instead of silently
                    # leaving the model uncompressed.
                    trainer.logger.warning(
                        f'compress_module "{cm}" was not found in the model '
                        'and is skipped')

        # Collected once: the model is not modified again below.
        sparse_modules = [
            module for _, module in trainer.model.named_modules()
            if isinstance(module, SparseLinear)
        ]

        def _matches_sparse_weight(param):
            # Parameters are matched to SparseLinear weights by value,
            # compared after casting both sides to fp16.
            # NOTE(review): value equality can also match an unrelated
            # parameter that happens to hold identical values.
            param_half = param.half()
            return any(
                torch.equal(param_half, module.weight.data.half())
                for module in sparse_modules)

        # Drop the matched parameters from every existing group ...
        for group in trainer.optimizer.param_groups:
            group['params'] = [
                param for param in group['params']
                if not _matches_sparse_weight(param)
            ]

        # ... and register all trainable SparseLinear parameters as one
        # additional group.
        new_params = [
            p for module in sparse_modules for p in module.parameters()
            if p.requires_grad
        ]
        trainer.optimizer.add_param_group({'params': new_params})

        self.total_step = trainer.iters_per_epoch * trainer._max_epochs

    def before_train_iter(self, trainer):
        """Apply the scheduled sparsity ratio for the current step.

        Args:
            trainer: The running trainer; ``model`` and ``logger`` are used.
        """
        from .utils import schedule_sparsity_ratio, update_network_sparsity

        cur_sparsity = schedule_sparsity_ratio(
            self.step,
            self.total_step,
            self.frequency,
            self.initial_warmup,
            self.final_warmup,
            self.initial_sparsity,
            self.final_sparsity,
        )

        update_network_sparsity(trainer.model, cur_sparsity)

        # Only the master process reports progress.
        if is_master():
            trainer.logger.info(
                f'Step[{self.step}/{self.total_step}] current sparsity ratio = {cur_sparsity}'
            )

        self.step += 1

    def after_run(self, trainer):
        """Finalize the sparse model and save the compressed checkpoint.

        Args:
            trainer: The running trainer; ``model`` and ``logger`` are used.
        """
        from .utils import generate_sparse_model

        generate_sparse_model(trainer.model, logger=trainer.logger)
        self._save_checkpoint(trainer)

    def _save_checkpoint(self, trainer):
        """Save model and optimizer to ``<save_dir>/compress_model.pth``.

        The log line is emitted on the master process only; the save call
        itself runs on every process that reaches this method.
        """
        if is_master():
            trainer.logger.info('Saving checkpoint at final compress')
        cur_save_name = os.path.join(self.save_dir, 'compress_model.pth')
        save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
|