fx.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. from typing import Any, Callable
  2. from torch._C import _fx_map_aggregate, _fx_map_arg
  3. from torch.fx.immutable_collections import immutable_dict, immutable_list
  4. from torch.fx.node import Node
  5. from ..decorators import substitute_in_graph
  6. @substitute_in_graph(_fx_map_arg, can_constant_fold_through=True)
  7. def map_arg(a: Any, fn: Callable[[Node], Any]) -> Any:
  8. return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
  9. @substitute_in_graph(_fx_map_aggregate, can_constant_fold_through=True)
  10. def map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any:
  11. result: Any
  12. if isinstance(a, tuple):
  13. it = (map_aggregate(elem, fn) for elem in a)
  14. # Support NamedTuple (if it has `_fields`) by repacking into original type.
  15. result = type(a)(*it) if hasattr(a, "_fields") else tuple(it)
  16. elif isinstance(a, list):
  17. result = immutable_list([map_aggregate(elem, fn) for elem in a])
  18. elif isinstance(a, dict):
  19. result = immutable_dict([(k, map_aggregate(v, fn)) for k, v in a.items()])
  20. elif isinstance(a, slice):
  21. result = slice(
  22. map_aggregate(a.start, fn),
  23. map_aggregate(a.stop, fn),
  24. map_aggregate(a.step, fn),
  25. )
  26. else:
  27. result = fn(a)
  28. return result
  29. __all__ = [
  30. "map_arg",
  31. "map_aggregate",
  32. ]