utils.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import torch
  3. import torch.nn as nn
  4. from modelscope.utils.torch_utils import is_master
  5. class SparseBinarizer(torch.autograd.Function):
  6. @staticmethod
  7. def forward(ctx, mask_scores, sparsity):
  8. num_prune = int(mask_scores.numel() * sparsity)
  9. prune_indices = torch.argsort(mask_scores.reshape(-1))[:num_prune]
  10. mask = mask_scores.clone().fill_(1)
  11. mask.reshape(-1)[prune_indices] = 0.0
  12. return mask
  13. @staticmethod
  14. def backward(ctx, gradOutput):
  15. return gradOutput, None
  16. class SparseLinear(nn.Module):
  17. """
  18. Fully Connected layer with on the fly adaptive mask.
  19. """
  20. def __init__(
  21. self,
  22. module,
  23. pruning_method='pst',
  24. weight_rank=8,
  25. weight_beta=1.0,
  26. mask_rank=8,
  27. mask_alpha1=1.0,
  28. mask_alpha2=1.0,
  29. ):
  30. super(SparseLinear, self).__init__()
  31. self.module = module
  32. out_features = self.module.weight.shape[0]
  33. in_features = self.module.weight.shape[1]
  34. self.weight = self.module.weight
  35. self.module.weight = None
  36. self.module._parameters.pop('weight')
  37. self.pruning_method = pruning_method
  38. self.cur_sparsity = 0.0
  39. if self.pruning_method == 'pst':
  40. self.weight_rank = weight_rank
  41. self.weight_beta = weight_beta
  42. self.mask_rank = mask_rank
  43. self.mask_alpha1 = mask_alpha1
  44. self.mask_alpha2 = mask_alpha2
  45. # create trainable params
  46. self.weight_U = nn.Parameter(
  47. torch.randn(out_features, self.weight_rank).to(
  48. device=self.weight.device, dtype=self.weight.dtype))
  49. self.weight_V = nn.Parameter(
  50. torch.zeros(self.weight_rank, in_features).to(
  51. device=self.weight.device, dtype=self.weight.dtype))
  52. self.mask_scores_A = nn.Parameter(
  53. torch.randn(out_features, self.mask_rank).to(
  54. device=self.weight.device, dtype=self.weight.dtype))
  55. self.mask_scores_B = nn.Parameter(
  56. torch.zeros(self.mask_rank, in_features).to(
  57. device=self.weight.device, dtype=self.weight.dtype))
  58. self.mask_scores_R = nn.Parameter(
  59. torch.zeros(out_features).to(
  60. device=self.weight.device, dtype=self.weight.dtype))
  61. self.mask_scores_C = nn.Parameter(
  62. torch.zeros(in_features).to(
  63. device=self.weight.device, dtype=self.weight.dtype))
  64. self.weight.requires_grad = False
  65. if self.module.bias is not None:
  66. self.module.bias.requires_grad = False
  67. def forward(self, *inputs):
  68. if self.pruning_method == 'pst':
  69. weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V
  70. mask_scores = (
  71. weight.abs()
  72. + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B
  73. + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1)
  74. + self.mask_scores_C.unsqueeze(0)))
  75. mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity)
  76. masked_weight = mask * weight
  77. self.module.weight = masked_weight
  78. return self.module(*inputs)
  79. else:
  80. return self.module(*inputs)
  81. def convert(self):
  82. if self.pruning_method == 'pst':
  83. weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V
  84. mask_scores = (
  85. weight.abs()
  86. + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B
  87. + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1)
  88. + self.mask_scores_C.unsqueeze(0)))
  89. mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity)
  90. masked_weight = mask * weight
  91. self.module.weight = nn.Parameter(masked_weight.data)
  92. def _setattr(model, name, module):
  93. name_list = name.split('.')
  94. for name in name_list[:-1]:
  95. model = getattr(model, name)
  96. setattr(model, name_list[-1], module)
  97. def convert_sparse_network(
  98. model,
  99. pruning_method,
  100. weight_rank,
  101. weight_beta,
  102. mask_rank,
  103. mask_alpha1,
  104. mask_alpha2,
  105. logger=None,
  106. ):
  107. compress_module = [nn.Linear]
  108. try:
  109. from megatron_util import mpu
  110. compress_module.extend(
  111. [mpu.RowParallelLinear, mpu.ColumnParallelLinear])
  112. except ImportError:
  113. pass
  114. for name, module in model.named_modules():
  115. if type(module) in compress_module:
  116. new_module = SparseLinear(
  117. module,
  118. pruning_method,
  119. weight_rank,
  120. weight_beta,
  121. mask_rank,
  122. mask_alpha1,
  123. mask_alpha2,
  124. )
  125. # replace original module by new sparse module
  126. _setattr(model, name, new_module)
  127. if is_master():
  128. if logger:
  129. logger.info(f'convert {name} to sparse module.')
  130. else:
  131. print(f'convert {name} to sparse module.')
  132. def update_network_sparsity(model, sparsity):
  133. for name, module in model.named_modules():
  134. if isinstance(module, SparseLinear):
  135. module.cur_sparsity = sparsity
  136. def schedule_sparsity_ratio(
  137. step,
  138. total_step,
  139. frequency,
  140. initial_warmup,
  141. final_warmup,
  142. initial_sparsity,
  143. final_sparsity,
  144. ):
  145. if step <= initial_warmup * total_step:
  146. sparsity = initial_sparsity
  147. elif step > (total_step - final_warmup * total_step):
  148. sparsity = final_sparsity
  149. else:
  150. spars_warmup_steps = initial_warmup * total_step
  151. spars_schedu_steps = (final_warmup + initial_warmup) * total_step
  152. step = (step - spars_warmup_steps) // frequency * frequency
  153. mul_coeff = 1 - step / (total_step - spars_schedu_steps)
  154. sparsity = final_sparsity + (initial_sparsity - final_sparsity) * (
  155. mul_coeff**3)
  156. return sparsity
  157. def generate_sparse_model(model, logger=None):
  158. # generate sparse weight for saving
  159. for name, module in model.named_modules():
  160. if isinstance(module, SparseLinear):
  161. module.convert()
  162. _setattr(model, name, module.module)
  163. if is_master():
  164. if logger:
  165. logger.info(f'convert {name} weight to sparse weight, \
  166. sparsity ratio={torch.mean(1.0*(module.module.weight==0)).item()}.'
  167. )
  168. else:
  169. print(f'convert {name} weight to sparse, \
  170. sparsity ratio={torch.mean(1.0*(module.module.weight==0)).item()}.'
  171. )