custom_ops.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. # mypy: allow-untyped-defs
  2. import importlib
  3. import torch
# Library handle for the "export" operator namespace. The "FRAGMENT" kind lets
# this Library add definitions to a namespace that other Library objects may
# also contribute to (see torch.library.Library docs).
lib = torch.library.Library("export", "FRAGMENT")  # noqa: TOR901
# Declares export::access_subclass_inner_tensor: given a tensor and an attribute
# name, it returns a Tensor. Its kernels are registered via
# @torch.library.impl on _access_subclass_inner_tensor in this file.
lib.define(
    "access_subclass_inner_tensor(Tensor src_subclass_tensor, str attr) -> Tensor"
)
  8. @torch.library.impl(lib, "access_subclass_inner_tensor", "Autograd")
  9. # When running under torch.inference_mode(), we seem to skip AUtograd key
  10. # so we should desugar this op as soon as we start tracing to post-dispatch.
  11. @torch.library.impl(lib, "access_subclass_inner_tensor", "Python")
  12. def _access_subclass_inner_tensor(
  13. src_subclass_tensor: torch.Tensor, attr: str
  14. ) -> torch.Tensor:
  15. from torch.utils._python_dispatch import is_traceable_wrapper_subclass
  16. assert is_traceable_wrapper_subclass(src_subclass_tensor)
  17. val = getattr(src_subclass_tensor, attr, None)
  18. if val is None or not isinstance(val, torch.Tensor):
  19. raise RuntimeError(
  20. f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}"
  21. )
  22. return val
  23. def _call_custom_autograd_function_in_pre_dispatch(function_cls_name, *args, **kwargs):
  24. """
  25. Import a custom autograd function by string name and call it. This is pretty bad
  26. because:
  27. 1) There is no schema
  28. Ideally we should automatically wrap custom autograd functions with a custom op, but
  29. that is too much work because we need to schematize custom autograd functions. For now,
  30. we just hackily put it in the IR.
  31. """
  32. # Parse module and class name
  33. module_name, class_name = function_cls_name.rsplit(".", 1)
  34. # Import the module and get the class
  35. module = importlib.import_module(module_name)
  36. function_cls = getattr(module, class_name)
  37. assert hasattr(function_cls, "apply")
  38. return function_cls.apply(*args, **kwargs)