param_fetch.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from typing import Any, Callable
  2. import torch
  3. import torch.nn as nn
  4. from torch.fx._compatibility import compatibility
  5. from torch.fx.graph_module import GraphModule
  6. __all__ = [
  7. "default_matching",
  8. "extract_attrs_for_lowering",
  9. "lift_lowering_attrs_to_nodes",
  10. ]
  11. # Matching method matches the attribute name of current version to the attribute name of `target_version`
  12. @compatibility(is_backward_compatible=False)
  13. def default_matching(name: str, target_version: int) -> str:
  14. """Default matching method"""
  15. return name
  16. # This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering.
  17. # The first integer in the tuple is the version number of the nn.Module class when we create the parameter list.
  18. # If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module.
  19. module_fetch_book: dict[type, tuple[int, list[str], Callable[[str, int], str]]] = {
  20. torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
  21. torch.nn.modules.conv.Conv2d: (
  22. 1,
  23. [
  24. "weight",
  25. "bias",
  26. "kernel_size",
  27. "stride",
  28. "padding",
  29. "dilation",
  30. "groups",
  31. "padding_mode",
  32. ],
  33. default_matching,
  34. ),
  35. torch.nn.modules.batchnorm.BatchNorm2d: (
  36. 2,
  37. ["weight", "bias", "running_mean", "running_var", "eps"],
  38. default_matching,
  39. ),
  40. torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching),
  41. torch.nn.modules.pooling.MaxPool2d: (
  42. 1,
  43. ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"],
  44. default_matching,
  45. ),
  46. torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching),
  47. }
  48. @compatibility(is_backward_compatible=False)
  49. def extract_attrs_for_lowering(mod: nn.Module) -> dict[str, Any]:
  50. """If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book`
  51. after checking module's version is compatible with the `module_fetch_book`.
  52. """
  53. attrs_for_lowering: dict[str, Any] = {}
  54. attrs_for_lowering["name"] = torch.typename(mod)
  55. if type(mod) in module_fetch_book:
  56. version, param_to_fetch, matching_method = module_fetch_book[type(mod)]
  57. if version < mod._version:
  58. raise RuntimeError(
  59. f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, "
  60. "please upgrade the module_fetch_book, open an issue and @842974287 "
  61. "or report a bug to AIACC team directly."
  62. )
  63. for attr in param_to_fetch:
  64. attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version))
  65. else:
  66. raise RuntimeError(
  67. f"{torch.typename(mod)} is not in the module_fetch_book yet, "
  68. "please add it to the module_fetch_book, open an issue and @842974287 "
  69. "or report a bug to AIACC team directly."
  70. )
  71. return attrs_for_lowering
  72. @compatibility(is_backward_compatible=False)
  73. def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
  74. """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module."""
  75. submodules = dict(fx_module.named_modules())
  76. for node in fx_module.graph.nodes:
  77. if node.op == "call_module":
  78. if isinstance(submodules[node.target], GraphModule):
  79. lift_lowering_attrs_to_nodes(submodules[node.target])
  80. else:
  81. node.attrs_for_lowering = extract_attrs_for_lowering(
  82. submodules[node.target]
  83. )