graph_manipulation.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # mypy: allow-untyped-defs
  2. from typing import Any, NamedTuple, Optional
  3. import torch
  4. from torch.fx._compatibility import compatibility
  5. from torch.fx.graph import Graph
  6. from torch.fx.graph_module import GraphModule
  7. from torch.fx.node import map_arg, Node, Target
  8. from torch.fx.passes.shape_prop import ShapeProp
  9. __all__ = [
  10. "replace_target_nodes_with",
  11. "size_bytes",
  12. "get_size_of_all_nodes",
  13. "get_tensor_meta",
  14. "get_size_of_node",
  15. ]
  16. @compatibility(is_backward_compatible=False)
  17. def replace_target_nodes_with(
  18. fx_module: GraphModule,
  19. old_op: str,
  20. old_target: Target,
  21. new_op: str,
  22. new_target: Target,
  23. ):
  24. """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
  25. and updates them to match the new op code and target"""
  26. new_graph = Graph()
  27. val_map: dict[Node, Node] = {}
  28. for node in fx_module.graph.nodes:
  29. if node.op == old_op and node.target == old_target:
  30. args = map_arg(node.args, lambda n: val_map[n])
  31. kwargs = map_arg(node.kwargs, lambda n: val_map[n])
  32. assert isinstance(args, tuple)
  33. assert isinstance(kwargs, dict)
  34. val_map[node] = new_graph.create_node(
  35. new_op, new_target, args, kwargs, node.name
  36. )
  37. else:
  38. val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
  39. fx_module.graph = new_graph
  40. @compatibility(is_backward_compatible=False)
  41. class size_bytes(NamedTuple):
  42. output_size: int
  43. total_size: int
  44. @compatibility(is_backward_compatible=False)
  45. def get_size_of_all_nodes(
  46. fx_module: GraphModule, args: Optional[list[torch.Tensor]] = None
  47. ) -> None:
  48. """Given a fx graph module, update each node with its total size (weights + bias + output)
  49. and its output_size(output). For a non-module node, the total size is the output size.
  50. return total size"""
  51. if args is not None:
  52. # Mark shape and dtype for each node (node.shape and node.dtype)
  53. ShapeProp(fx_module).propagate(*args)
  54. # Calculate the total size of the whole fx graph
  55. for node in fx_module.graph.nodes:
  56. if node.op == "output":
  57. break
  58. node.size_bytes = get_size_of_node(fx_module, node)
  59. return
  60. @compatibility(is_backward_compatible=False)
  61. def get_tensor_meta(node: Node) -> Any:
  62. tensor_meta = node.meta.get("tensor_meta")
  63. if not tensor_meta:
  64. raise RuntimeError(
  65. f"Node {node} has no tensor metadata associated with it! "
  66. f"Check that shape propagation has run."
  67. )
  68. return tensor_meta
  69. @compatibility(is_backward_compatible=False)
  70. def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
  71. """Given a node with node.dtype and node.shape, return its total size and its output size.
  72. total_size = weights + bias + output_size
  73. """
  74. # Total num of elements
  75. total_num_of_elems = 0
  76. # For a module, consider all parameters
  77. if node.op == "call_module":
  78. submodule_dict = dict(fx_module.named_modules())
  79. submodule = submodule_dict[node.target]
  80. parameters = submodule.named_parameters()
  81. # Parameters are named tuples
  82. for _name, p in parameters:
  83. total_num_of_elems += p.numel()
  84. # Don't forget the output size
  85. # node.shape is the shape of this node's output
  86. tensor_meta = get_tensor_meta(node)
  87. output_elem = tensor_meta.shape.numel()
  88. total_num_of_elems += output_elem
  89. # Assume for now if it's quantized then it's qint8 or quint8
  90. if tensor_meta.is_quantized:
  91. size_per_elem_bytes = torch._empty_affine_quantized(
  92. [], dtype=tensor_meta.dtype
  93. ).element_size()
  94. else:
  95. size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size()
  96. total_size = size_per_elem_bytes * total_num_of_elems
  97. output_size = size_per_elem_bytes * output_elem
  98. return size_bytes(output_size, total_size)