fuse_modules.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. # mypy: allow-untyped-defs
  2. import copy
  3. from typing import Optional
  4. import torch.nn as nn
  5. # for backward compatibility
  6. from torch.ao.quantization.fuser_method_mappings import ( # noqa: F401 # noqa: F401
  7. fuse_conv_bn,
  8. fuse_conv_bn_relu,
  9. get_fuser_method,
  10. )
  11. from torch.nn.utils.parametrize import type_before_parametrizations
  12. __all__ = [
  13. "fuse_known_modules",
  14. "fuse_modules",
  15. "fuse_modules_qat",
  16. ]
  17. # Generalization of getattr
  18. def _get_module(model, submodule_key):
  19. tokens = submodule_key.split(".")
  20. cur_mod = model
  21. for s in tokens:
  22. cur_mod = getattr(cur_mod, s)
  23. return cur_mod
  24. # Generalization of setattr
  25. def _set_module(model, submodule_key, module):
  26. tokens = submodule_key.split(".")
  27. sub_tokens = tokens[:-1]
  28. cur_mod = model
  29. for s in sub_tokens:
  30. cur_mod = getattr(cur_mod, s)
  31. setattr(cur_mod, tokens[-1], module)
  32. def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
  33. r"""Return a list of known fuse modules.
  34. Returns a list of modules that fuses the operations specified
  35. in the input module list.
  36. Fuses only the following sequence of modules:
  37. conv, bn
  38. conv, bn, relu
  39. conv, relu
  40. linear, bn
  41. linear, relu
  42. For these sequences, the first element in the output module list performs
  43. the fused operation. The rest of the elements are set to nn.Identity()
  44. """
  45. types = tuple(type_before_parametrizations(m) for m in mod_list)
  46. fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
  47. if fuser_method is None:
  48. raise NotImplementedError(f"Cannot fuse modules: {types}")
  49. new_mod: list[Optional[nn.Module]] = [None] * len(mod_list)
  50. fused = fuser_method(is_qat, *mod_list)
  51. # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
  52. # Move pre forward hooks of the base module to resulting fused module
  53. for pre_hook_fn in mod_list[0]._forward_pre_hooks.values():
  54. fused.register_forward_pre_hook(pre_hook_fn)
  55. mod_list[0]._forward_pre_hooks.clear()
  56. # Move post forward hooks of the last module to resulting fused module
  57. for hook_fn in mod_list[-1]._forward_hooks.values():
  58. fused.register_forward_hook(hook_fn)
  59. mod_list[-1]._forward_hooks.clear()
  60. new_mod[0] = fused
  61. for i in range(1, len(mod_list)):
  62. identity = nn.Identity()
  63. identity.training = mod_list[0].training
  64. new_mod[i] = identity
  65. return new_mod
  66. def _fuse_modules_helper(
  67. model,
  68. modules_to_fuse,
  69. is_qat,
  70. fuser_func=fuse_known_modules,
  71. fuse_custom_config_dict=None,
  72. ):
  73. if fuse_custom_config_dict is None:
  74. fuse_custom_config_dict = {}
  75. additional_fuser_method_mapping = fuse_custom_config_dict.get(
  76. "additional_fuser_method_mapping", {}
  77. )
  78. mod_list = [_get_module(model, item) for item in modules_to_fuse]
  79. # Fuse list of modules
  80. new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)
  81. # Replace original module list with fused module list
  82. for i, item in enumerate(modules_to_fuse):
  83. _set_module(model, item, new_mod_list[i])
  84. def _fuse_modules(
  85. model,
  86. modules_to_fuse,
  87. is_qat,
  88. inplace=False,
  89. fuser_func=fuse_known_modules,
  90. fuse_custom_config_dict=None,
  91. ):
  92. if not inplace:
  93. model = copy.deepcopy(model)
  94. if all(isinstance(module_element, str) for module_element in modules_to_fuse):
  95. # Handle case of modules_to_fuse being a list
  96. _fuse_modules_helper(
  97. model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict
  98. )
  99. else:
  100. # Handle case of modules_to_fuse being a list of lists
  101. for module_list in modules_to_fuse:
  102. _fuse_modules_helper(
  103. model, module_list, is_qat, fuser_func, fuse_custom_config_dict
  104. )
  105. return model
  106. def fuse_modules(
  107. model,
  108. modules_to_fuse,
  109. inplace=False,
  110. fuser_func=fuse_known_modules,
  111. fuse_custom_config_dict=None,
  112. ):
  113. r"""Fuse a list of modules into a single module.
  114. Fuses only the following sequence of modules:
  115. conv, bn
  116. conv, bn, relu
  117. conv, relu
  118. linear, relu
  119. bn, relu
  120. All other sequences are left unchanged.
  121. For these sequences, replaces the first item in the list
  122. with the fused module, replacing the rest of the modules
  123. with identity.
  124. Args:
  125. model: Model containing the modules to be fused
  126. modules_to_fuse: list of list of module names to fuse. Can also be a list
  127. of strings if there is only a single list of modules to fuse.
  128. inplace: bool specifying if fusion happens in place on the model, by default
  129. a new model is returned
  130. fuser_func: Function that takes in a list of modules and outputs a list of fused modules
  131. of the same length. For example,
  132. fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
  133. Defaults to torch.ao.quantization.fuse_known_modules
  134. `fuse_custom_config_dict`: custom configuration for fusion
  135. .. code-block:: python
  136. # Example of fuse_custom_config_dict
  137. fuse_custom_config_dict = {
  138. # Additional fuser_method mapping
  139. "additional_fuser_method_mapping": {
  140. (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
  141. },
  142. }
  143. Returns:
  144. model with fused modules. A new copy is created if inplace=True.
  145. Examples::
  146. >>> # xdoctest: +SKIP
  147. >>> m = M().eval()
  148. >>> # m is a module containing the sub-modules below
  149. >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
  150. >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
  151. >>> output = fused_m(input)
  152. >>> m = M().eval()
  153. >>> # Alternately provide a single list of modules to fuse
  154. >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
  155. >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
  156. >>> output = fused_m(input)
  157. """
  158. return _fuse_modules(
  159. model,
  160. modules_to_fuse,
  161. is_qat=False,
  162. inplace=inplace,
  163. fuser_func=fuser_func,
  164. fuse_custom_config_dict=fuse_custom_config_dict,
  165. )
  166. def fuse_modules_qat(
  167. model,
  168. modules_to_fuse,
  169. inplace=False,
  170. fuser_func=fuse_known_modules,
  171. fuse_custom_config_dict=None,
  172. ):
  173. """QAT version for `fuse_modules`."""
  174. return _fuse_modules(
  175. model,
  176. modules_to_fuse,
  177. is_qat=True,
  178. inplace=inplace,
  179. fuser_func=fuser_func,
  180. fuse_custom_config_dict=fuse_custom_config_dict,
  181. )