local.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from __future__ import annotations
  2. import threading
  3. from contextlib import contextmanager
  4. from typing import TYPE_CHECKING
  5. if TYPE_CHECKING:
  6. from collections.abc import Iterator
# Simple dynamic scoping implementation. The name "parametrize" comes
# from Racket.
#
# WARNING WARNING: LOOKING TO EDIT THIS FILE? Think carefully about
# why you need to add a toggle to the global behavior of code
# generation. The parameters here should really only be used
# for "temporary" situations, where we need to temporarily change
# the codegen in some cases because we cannot conveniently update
# all call sites, and are slated for removal once all call sites
# have been migrated. If you don't have a plan for how to get there,
# DON'T add a new entry here.
  18. class Locals(threading.local):
  19. use_const_ref_for_mutable_tensors: bool | None = None
  20. use_ilistref_for_tensor_lists: bool | None = None
  21. _locals = Locals()
  22. def use_const_ref_for_mutable_tensors() -> bool:
  23. assert _locals.use_const_ref_for_mutable_tensors is not None, (
  24. "need to initialize local.use_const_ref_for_mutable_tensors with "
  25. "local.parametrize"
  26. )
  27. return _locals.use_const_ref_for_mutable_tensors
  28. def use_ilistref_for_tensor_lists() -> bool:
  29. assert _locals.use_ilistref_for_tensor_lists is not None, (
  30. "need to initialize local.use_ilistref_for_tensor_lists with local.parametrize"
  31. )
  32. return _locals.use_ilistref_for_tensor_lists
  33. @contextmanager
  34. def parametrize(
  35. *, use_const_ref_for_mutable_tensors: bool, use_ilistref_for_tensor_lists: bool
  36. ) -> Iterator[None]:
  37. old_use_const_ref_for_mutable_tensors = _locals.use_const_ref_for_mutable_tensors
  38. old_use_ilistref_for_tensor_lists = _locals.use_ilistref_for_tensor_lists
  39. try:
  40. _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
  41. _locals.use_ilistref_for_tensor_lists = use_ilistref_for_tensor_lists
  42. yield
  43. finally:
  44. _locals.use_const_ref_for_mutable_tensors = (
  45. old_use_const_ref_for_mutable_tensors
  46. )
  47. _locals.use_ilistref_for_tensor_lists = old_use_ilistref_for_tensor_lists