_prune.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import os
  2. import pkgutil
  3. from copy import deepcopy
  4. from torch import nn as nn
  5. from timm.layers import Conv2dSame, BatchNormAct2d, Linear
  6. __all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
  7. def extract_layer(model, layer):
  8. layer = layer.split('.')
  9. module = model
  10. if hasattr(model, 'module') and layer[0] != 'module':
  11. module = model.module
  12. if not hasattr(model, 'module') and layer[0] == 'module':
  13. layer = layer[1:]
  14. for l in layer:
  15. if hasattr(module, l):
  16. if not l.isdigit():
  17. module = getattr(module, l)
  18. else:
  19. module = module[int(l)]
  20. else:
  21. return module
  22. return module
  23. def set_layer(model, layer, val):
  24. layer = layer.split('.')
  25. module = model
  26. if hasattr(model, 'module') and layer[0] != 'module':
  27. module = model.module
  28. lst_index = 0
  29. module2 = module
  30. for l in layer:
  31. if hasattr(module2, l):
  32. if not l.isdigit():
  33. module2 = getattr(module2, l)
  34. else:
  35. module2 = module2[int(l)]
  36. lst_index += 1
  37. lst_index -= 1
  38. for l in layer[:lst_index]:
  39. if not l.isdigit():
  40. module = getattr(module, l)
  41. else:
  42. module = module[int(l)]
  43. l = layer[lst_index]
  44. setattr(module, l, val)
  45. def adapt_model_from_string(parent_module, model_string):
  46. separator = '***'
  47. state_dict = {}
  48. lst_shape = model_string.split(separator)
  49. for k in lst_shape:
  50. k = k.split(':')
  51. key = k[0]
  52. shape = k[1][1:-1].split(',')
  53. if shape[0] != '':
  54. state_dict[key] = [int(i) for i in shape]
  55. # Extract device and dtype from the parent module
  56. device = next(parent_module.parameters()).device
  57. dtype = next(parent_module.parameters()).dtype
  58. dd = {'device': device, 'dtype': dtype}
  59. new_module = deepcopy(parent_module)
  60. for n, m in parent_module.named_modules():
  61. old_module = extract_layer(parent_module, n)
  62. if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
  63. if isinstance(old_module, Conv2dSame):
  64. conv = Conv2dSame
  65. else:
  66. conv = nn.Conv2d
  67. s = state_dict[n + '.weight']
  68. in_channels = s[1]
  69. out_channels = s[0]
  70. g = 1
  71. if old_module.groups > 1:
  72. in_channels = out_channels
  73. g = in_channels
  74. new_conv = conv(
  75. in_channels=in_channels,
  76. out_channels=out_channels,
  77. kernel_size=old_module.kernel_size,
  78. bias=old_module.bias is not None,
  79. padding=old_module.padding,
  80. dilation=old_module.dilation,
  81. groups=g,
  82. stride=old_module.stride,
  83. **dd,
  84. )
  85. set_layer(new_module, n, new_conv)
  86. elif isinstance(old_module, BatchNormAct2d):
  87. new_bn = BatchNormAct2d(
  88. state_dict[n + '.weight'][0],
  89. eps=old_module.eps,
  90. momentum=old_module.momentum,
  91. affine=old_module.affine,
  92. track_running_stats=True,
  93. **dd,
  94. )
  95. new_bn.drop = old_module.drop
  96. new_bn.act = old_module.act
  97. set_layer(new_module, n, new_bn)
  98. elif isinstance(old_module, nn.BatchNorm2d):
  99. new_bn = nn.BatchNorm2d(
  100. num_features=state_dict[n + '.weight'][0],
  101. eps=old_module.eps,
  102. momentum=old_module.momentum,
  103. affine=old_module.affine,
  104. track_running_stats=True,
  105. **dd,
  106. )
  107. set_layer(new_module, n, new_bn)
  108. elif isinstance(old_module, nn.Linear):
  109. # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
  110. num_features = state_dict[n + '.weight'][1]
  111. new_fc = Linear(
  112. in_features=num_features,
  113. out_features=old_module.out_features,
  114. bias=old_module.bias is not None,
  115. **dd,
  116. )
  117. set_layer(new_module, n, new_fc)
  118. if hasattr(new_module, 'num_features'):
  119. if getattr(new_module, 'head_hidden_size', 0) == new_module.num_features:
  120. new_module.head_hidden_size = num_features
  121. new_module.num_features = num_features
  122. new_module.eval()
  123. parent_module.eval()
  124. return new_module
  125. def adapt_model_from_file(parent_module, model_variant):
  126. adapt_data = pkgutil.get_data(__name__, os.path.join('_pruned', model_variant + '.txt'))
  127. return adapt_model_from_string(parent_module, adapt_data.decode('utf-8').strip())