loader.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
# Used to load and initialize polyfill handlers when importing torch._dynamo.
# When adding a new polyfill module, add its name to POLYFILLED_MODULE_NAMES
# below so that it is imported by this loader.
  3. import importlib
  4. from typing import TYPE_CHECKING
  5. from .. import polyfills, trace_rules
  6. if TYPE_CHECKING:
  7. from types import ModuleType
  8. # See also the TYPE_CHECKING block in torch/_dynamo/polyfills/__init__.py
  9. POLYFILLED_MODULE_NAMES: tuple[str, ...] = (
  10. "_collections",
  11. "builtins",
  12. "functools",
  13. "itertools",
  14. "operator",
  15. "os",
  16. "pytree",
  17. "struct",
  18. "sys",
  19. "fx",
  20. "tensor",
  21. )
  22. POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple(
  23. importlib.import_module(f".{submodule}", package=polyfills.__name__)
  24. for submodule in POLYFILLED_MODULE_NAMES
  25. )
  26. # Unregister the builtin functions from _builtin_function_ids to let them to be
  27. # dispatched with the appropriate VariableTracker type. Otherwise, they will be
  28. # dispatched with BuiltinVariable if present in _builtin_function_ids.
  29. for polyfill_module in POLYFILLED_MODULES:
  30. for polyfill_name in polyfill_module.__all__:
  31. polyfill_handler = getattr(polyfill_module, polyfill_name)
  32. original_fn = polyfill_handler.__torch_dynamo_original__
  33. trace_rules._builtin_function_ids.remove(id(original_fn))