rewriter.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. import ast
  4. import copy
  5. import functools
  6. import inspect
  7. import textwrap
  8. from types import FunctionType
  9. from typing import Any, Callable, cast, Optional, Union
  10. import torch
  11. from torch._sources import normalize_source_lines
  12. from torch.fx._symbolic_trace import Tracer
  13. from torch.fx.graph import Graph
  14. class AST_Rewriter(ast.NodeTransformer):
  15. """
  16. Take a FunctionType object representing a `forward` method, then
  17. perform an AST rewrite to swap out nodes that are not symbolically
  18. traceable with a callsite to the FX alternative.
  19. To support swapping out an AST node, define a new `visit` method on
  20. that node. For more details, see:
  21. https://docs.python.org/3/library/ast.html#ast.NodeTransformer
  22. """
  23. # This function checks for new keys added in the globals dict. TorchDynamo
  24. # can insert new keys in the global dict and upset the check. Therefore, put
  25. # a disable here. This function is an optimization pass and not really
  26. # suitable for dynamo tracing anyways.
  27. @torch._dynamo.disable
  28. def rewrite(self, fn: FunctionType):
  29. # Normalize the source lines
  30. sourcelines, _ = inspect.getsourcelines(fn)
  31. sourcelines = normalize_source_lines(sourcelines)
  32. source = "".join(sourcelines)
  33. normalized_str = textwrap.dedent(source)
  34. # Rewrite the original AST
  35. source_ast = ast.parse(normalized_str)
  36. dest_ast = ast.fix_missing_locations(self.visit(source_ast))
  37. # Pull out the compiled function from the newly-created Module
  38. code = compile(dest_ast, "", "exec")
  39. globals_dict = copy.copy(fn.__globals__)
  40. keys_before = set(globals_dict.keys())
  41. exec(code, globals_dict)
  42. new_keys = list(set(globals_dict.keys()) - keys_before)
  43. assert len(new_keys) == 1
  44. fn_compiled = globals_dict[new_keys[0]]
  45. # return the compiled function with the original globals
  46. def change_func_globals(f, globals):
  47. """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
  48. # __globals__ is a private member of the function class
  49. # so we have to copy the function, f, all of its member, except f.__globals__
  50. g = FunctionType(
  51. f.__code__,
  52. globals,
  53. name=f.__name__,
  54. argdefs=f.__defaults__,
  55. closure=f.__closure__,
  56. )
  57. g = functools.update_wrapper(g, f)
  58. g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined]
  59. return g
  60. # Return the correct FunctionType object
  61. return change_func_globals(fn_compiled, globals=fn.__globals__)
  62. def visit_Assert(self, node):
  63. """
  64. Swap out the Assert node (Python's `assert`) with a callsite to the
  65. symbolically-traceable torch._assert function
  66. """
  67. # Create the Call node
  68. n = ast.parse("torch._assert()", mode="eval")
  69. assert isinstance(n, ast.Expression)
  70. call_node = n.body
  71. assert isinstance(call_node, ast.Call)
  72. msg = node.msg if node.msg else ast.Constant(value="", kind=None)
  73. call_node.args = [node.test, msg]
  74. # Ensure that the new node conforms to the Python AST grammar
  75. expr_wrapper = ast.Expr(value=call_node)
  76. # Return the new Call node to signify that we want to use it as
  77. # a replacement for the original _assert node
  78. return ast.copy_location(expr_wrapper, node)
  79. def visit_AnnAssign(self, node):
  80. """
  81. Swap out Python's AnnAssign with an Assign node where the annotation function is called.
  82. Example:
  83. Original:
  84. y: Tensor_Type(1,2,3, Dyn) = f2(x)
  85. Output:
  86. y = annotate(f2(x),Tensor_Type((1,2,3,Dyn)))
  87. """
  88. return ast.Assign(
  89. targets=[node.target],
  90. value=ast.Call(
  91. func=ast.Name(id="annotate", ctx=ast.Load()),
  92. args=[node.value, node.annotation],
  93. keywords=[],
  94. ),
  95. )
  96. class RewritingTracer(Tracer):
  97. def trace(
  98. self,
  99. root: Union[torch.nn.Module, Callable],
  100. concrete_args: Optional[dict[str, Any]] = None,
  101. ) -> Graph:
  102. return super().trace(_rewrite(root), concrete_args)
  103. def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
  104. if isinstance(fn, torch.nn.Module):
  105. # Rewrite this module's `forward` as well as the `forward`s of
  106. # all of this module's recursive descendents. Return the new,
  107. # rewritten module hierarchy.
  108. def rewrite_module(m: torch.nn.Module):
  109. class RewrittenModule(torch.nn.Module):
  110. def __init__(self, orig):
  111. super().__init__()
  112. for k, v in orig.__dict__.items():
  113. if isinstance(v, torch.nn.Module):
  114. self.__dict__[k] = copy.copy(rewrite_module(v))
  115. else:
  116. self.__dict__[k] = copy.copy(v)
  117. RewrittenModule.forward = AST_Rewriter().rewrite(
  118. cast(FunctionType, m.forward)
  119. )
  120. return RewrittenModule(m)
  121. return rewrite_module(fn)
  122. else:
  123. # Rewrite this single free function
  124. return AST_Rewriter().rewrite(cast(FunctionType, fn))