effects.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from enum import Enum
  2. from typing import Optional
  3. import torch
  class EffectType(Enum):
      """Kinds of side effect that can be associated with an operator."""

      # Assigned by EffectHolder to ops that take a side-effecting ScriptObject
      # input; presumably such ops must keep their relative execution order —
      # confirm with the consumers of this enum.
      ORDERED = "Ordered"
  6. from torch._library.utils import RegistrationHandle
  # Script classes that only hold packed (quantization) parameters and have no
  # side effects. An op that takes one of these as an input therefore does not
  # need to be marked as ordered; EffectHolder skips these type names when
  # deciding an op's default effect.
  skip_classes = (
      # Quantized packed-parameter containers.
      "__torch__.torch.classes.quantized.Conv2dPackedParamsBase",
      "__torch__.torch.classes.quantized.Conv3dPackedParamsBase",
      "__torch__.torch.classes.quantized.EmbeddingPackedParamsBase",
      "__torch__.torch.classes.quantized.LinearPackedParamsBase",
      # XNNPACK op contexts.
      "__torch__.torch.classes.xnnpack.Conv2dOpContext",
      "__torch__.torch.classes.xnnpack.LinearOpContext",
      "__torch__.torch.classes.xnnpack.TransposeConv2dOpContext",
  )
  class EffectHolder:
      """A holder where one can register an effect impl to.

      Tracks the side-effect type (an ``EffectType`` or ``None``) associated
      with a single operator, identified by its qualified name. On construction
      (and whenever a registration is undone) the effect is derived from the
      operator's schema; it can be overridden with :meth:`register`.
      """

      def __init__(self, qualname: str):
          # Operator name, parsed by ``_set_default_effect`` into a namespace,
          # an op name and an optional ".overload" suffix.
          self.qualname: str = qualname
          self._set_default_effect()

      def _set_default_effect(self) -> None:
          """Reset ``_effect`` to the default derived from the op's schema.

          The effect defaults to ``None``. It becomes ``EffectType.ORDERED``
          when the op's schema has an argument whose type is a
          ``torch.ClassType`` (a ScriptObject) not listed in ``skip_classes``.
          Ops in the ``higher_order`` namespace, and ops whose overload can no
          longer be found, keep the ``None`` default.

          Raises:
              AssertionError: if the op name contains more than one '.'.
          """
          self._effect: Optional[EffectType] = None

          # If the op contains a ScriptObject input, we want to mark it as having effects
          namespace, opname = torch._library.utils.parse_namespace(self.qualname)
          split = opname.split(".")
          if len(split) > 1:
              assert len(split) == 2, (
                  f"Tried to split {opname} based on '.' but found more than 1 '.'"
              )
              opname, overload = split
          else:
              # No explicit overload; an empty overload name is passed to the
              # lookups below (presumably selecting the default overload).
              overload = ""

          if namespace == "higher_order":
              return

          opname = f"{namespace}::{opname}"
          # Since we call this when destroying the library, sometimes the
          # schema will be gone already at that time, so only inspect it if the
          # overload is still registered.
          if torch._C._get_operation_overload(opname, overload) is None:
              return

          schema = torch._C._get_schema(opname, overload)
          for arg in schema.arguments:
              if not isinstance(arg.type, torch.ClassType):
                  continue
              type_str = arg.type.str()  # pyrefly: ignore[missing-attribute]
              if type_str in skip_classes:
                  continue
              # The first non-skipped ScriptObject argument decides the effect.
              self._effect = EffectType.ORDERED
              return

      @property
      def effect(self) -> Optional[EffectType]:
          """The effect currently associated with the operator (read-only)."""
          return self._effect

      @effect.setter
      def effect(self, _):
          # Direct assignment is disallowed; callers must go through register()
          # so the change can be undone through the returned handle.
          raise RuntimeError("Unable to directly set effect. Use register() instead.")

      def register(self, effect: Optional[EffectType]) -> RegistrationHandle:
          """Register an effect.

          Args:
              effect: The effect to associate with the operator; ``None``
                  marks it as having no effect.

          Returns:
              A RegistrationHandle that one can use to de-register this effect.
          """
          self._effect = effect

          def deregister_effect() -> None:
              # NOTE(review): this recomputes the schema-derived default rather
              # than restoring whatever effect was registered before this call.
              self._set_default_effect()

          return RegistrationHandle(deregister_effect)