utils.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # mypy: allow-untyped-defs
  2. from torch.ao.quantization.pt2e.utils import _is_sym_size_node
  3. from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
  4. from torch.fx import Node
  5. def _annotate_input_qspec_map(node: Node, input_node: Node, qspec):
  6. quantization_annotation = node.meta.get(
  7. "quantization_annotation", QuantizationAnnotation()
  8. )
  9. if quantization_annotation.input_qspec_map is None:
  10. quantization_annotation.input_qspec_map = {}
  11. quantization_annotation.input_qspec_map[input_node] = qspec
  12. node.meta["quantization_annotation"] = quantization_annotation
  13. def _annotate_output_qspec(node: Node, qspec):
  14. quantization_annotation = node.meta.get(
  15. "quantization_annotation", QuantizationAnnotation()
  16. )
  17. quantization_annotation.output_qspec = qspec
  18. node.meta["quantization_annotation"] = quantization_annotation
  19. def _node_only_used_for_sym_size(node: Node, partition_nodes: list[Node]):
  20. """
  21. This utility is used to handle cases when dynami_shape=True tracing leads
  22. to symint nodes in the pattern of linear module. In those cases, we need to
  23. distinguish between the nodes that are in input for just extracting value of
  24. some dimensions (and symint nodes) vs. the one that is activation.
  25. For example:
  26. graph(x, y, weight):
  27. size_0 = torch.ops.aten.sym_size([x], [0])
  28. size_1 = torch.ops.aten.sym_size([y], [1])
  29. view_size = size_0 * size_1
  30. size_3 = torch.ops.aten.sym_size([x], [2])
  31. vie_out = torch.ops.aten.view(x, [view_size, size_3])
  32. return mm(view_out, weight)
  33. In the example above y node is not actual input. It exist only to extract size_1
  34. """
  35. if _is_sym_size_node(node):
  36. return True
  37. return all(
  38. ((user not in partition_nodes) or _is_sym_size_node(user))
  39. for user in node.users
  40. )
  41. def _get_module_name_filter(module_name: str):
  42. """Get the module_name_filter function for a given module name, the filter accepts
  43. a node and checks if the node comes from a module that has certain module name
  44. For example:
  45. node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1
  46. >> module_name_filter = _get_module_name_filter("blocks.sub")
  47. >> print(module_name_filter(node))
  48. True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
  49. """
  50. def module_name_filter(n: Node) -> bool:
  51. # example: {
  52. # 'L__self___sub': ("L['self'].sub", <class '....Sub'>),
  53. # 'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
  54. # }
  55. # get_attr nodes doesn't have nn_module_stack?
  56. nn_module_stack = n.meta.get("nn_module_stack", {})
  57. def _normalize_path(n):
  58. prefix = 0
  59. # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph.
  60. if n.startswith("L['self']."):
  61. prefix = len("L['self'].")
  62. return n[prefix:]
  63. names = [_normalize_path(n) for n, _ in nn_module_stack.values()]
  64. return module_name in names
  65. return module_name_filter