lazy.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. import collections
  2. import functools
  3. import inspect
  4. from typing import Any, Callable, final, Optional, Union
  5. from typing_extensions import Self
  6. from ..utils import is_function_or_wrapper
  7. from .base import VariableTracker
  8. from .tensor import SymNodeVariable
  9. class LazyCache:
  10. """Container to cache the real VariableTracker"""
  11. def __init__(self, value: Any, source: Any) -> None:
  12. if not isinstance(value, LazySymNodeFormatString):
  13. assert source
  14. self.value = value
  15. self.source = source
  16. self.name_hint: Optional[str] = None
  17. self.vt: Optional[VariableTracker] = None
  18. def realize(self) -> None:
  19. assert self.vt is None
  20. from ..symbolic_convert import InstructionTranslator
  21. from . import builder
  22. tx = InstructionTranslator.current_tx()
  23. if isinstance(self.value, LazySymNodeFormatString):
  24. self.vt = builder.SourcelessBuilder.create(tx, self.value)
  25. else:
  26. self.vt = builder.VariableBuilder(tx, self.source)(self.value)
  27. if self.name_hint is not None:
  28. self.vt.set_name_hint(self.name_hint)
  29. del self.value
  30. del self.source
  31. del self.name_hint
  32. @final
  33. class LazyVariableTracker(VariableTracker):
  34. """
  35. A structure that defers the creation of the actual VariableTracker
  36. for a given underlying value until it is accessed.
  37. The `realize` function invokes VariableTracker.build() to produce the real object.
  38. Once a LazyVariableTracker has been realized, internal bookkeeping will
  39. prevent double realization.
  40. This object should be utilized for processing containers, or objects that
  41. reference other objects where we may not want to take on creating all the
  42. VariableTrackers right away.
  43. """
  44. _nonvar_fields = {"_cache", *VariableTracker._nonvar_fields}
  45. @staticmethod
  46. def create(value: Any, source: Any, **options: Any) -> "LazyVariableTracker":
  47. return LazyVariableTracker(LazyCache(value, source), source=source, **options)
  48. def __init__(self, _cache: LazyCache, **kwargs: Any) -> None:
  49. assert isinstance(_cache, LazyCache)
  50. super().__init__(**kwargs)
  51. self._cache = _cache
  52. def realize(self) -> VariableTracker:
  53. """Force construction of the real VariableTracker"""
  54. if self._cache.vt is None:
  55. self._cache.realize()
  56. assert self._cache.vt is not None
  57. return self._cache.vt
  58. def unwrap(self) -> Union[VariableTracker, Self]:
  59. """Return the real VariableTracker if it already exists"""
  60. if self.is_realized():
  61. assert self._cache.vt is not None
  62. return self._cache.vt
  63. return self
  64. def is_realized(self) -> bool:
  65. return self._cache.vt is not None
  66. def clone(self, **kwargs: Any) -> VariableTracker:
  67. assert kwargs.get("_cache", self._cache) is self._cache
  68. if kwargs.get("source", self.source) is not self.source:
  69. self.realize()
  70. return VariableTracker.clone(self.unwrap(), **kwargs)
  71. def peek_type(self) -> type[Any]:
  72. assert not self.is_realized()
  73. return type(self._cache.value)
  74. def peek_value(self) -> Any:
  75. assert not self.is_realized()
  76. return self._cache.value
  77. def set_name_hint(self, name: str) -> None:
  78. if self.is_realized():
  79. self._cache.vt.set_name_hint(name) # type: ignore[union-attr]
  80. else:
  81. self._cache.name_hint = name
  82. def __str__(self) -> str:
  83. if self.is_realized():
  84. return repr(self.unwrap())
  85. return super().__repr__()
  86. def __getattr__(self, item: str) -> Any:
  87. return getattr(self.realize(), item)
  88. # most methods are auto-generated below, these are the ones we want to exclude
  89. visit = VariableTracker.visit # type: ignore[assignment]
  90. __repr__ = __str__
  91. @classmethod
  92. def realize_all(
  93. cls,
  94. value: Any,
  95. cache: Optional[dict[int, tuple[Any, Any]]] = None,
  96. ) -> Any:
  97. """
  98. Walk an object and realize all LazyVariableTrackers inside it.
  99. """
  100. if cache is None:
  101. cache = {}
  102. idx = id(value)
  103. if idx in cache:
  104. return cache[idx][0]
  105. value_cls = type(value)
  106. if issubclass(value_cls, LazyVariableTracker):
  107. result = cls.realize_all(value.realize(), cache)
  108. elif issubclass(value_cls, VariableTracker):
  109. # update value in-place
  110. result = value
  111. value_dict = value.__dict__
  112. nonvars = value._nonvar_fields
  113. for key in value_dict:
  114. if key not in nonvars:
  115. value_dict[key] = cls.realize_all(value_dict[key], cache)
  116. elif value_cls is list:
  117. result = [cls.realize_all(v, cache) for v in value]
  118. elif value_cls is tuple:
  119. result = tuple(cls.realize_all(v, cache) for v in value)
  120. elif value_cls in (dict, collections.OrderedDict):
  121. result = {k: cls.realize_all(v, cache) for k, v in list(value.items())}
  122. else:
  123. result = value
  124. # save `value` to keep it alive and ensure id() isn't reused
  125. cache[idx] = (result, value)
  126. return result
  127. def is_hashable(self) -> bool:
  128. # Checks that the underlying value is hashable without realizing the VT.
  129. # This is used by ConstDictVariable tracker to find if the key LazyVT
  130. # can be hashed.
  131. def _helper(value: Any) -> bool:
  132. # TODO: Add support for more types
  133. return (
  134. inspect.isbuiltin(value)
  135. or issubclass(type(value), type)
  136. or is_function_or_wrapper(value)
  137. )
  138. assert not self.is_realized()
  139. value = self._cache.value
  140. if isinstance(value, tuple):
  141. return all(_helper(v) for v in value)
  142. return _helper(value)
  143. def original_value(self) -> Any:
  144. # Returns the value without realizing the VT.
  145. assert not self.is_realized()
  146. return self._cache.value
  147. def original_source(self) -> Any:
  148. # Returns the source without realizing the VT.
  149. assert not self.is_realized()
  150. return self._cache.source
  151. class LazySymNodeFormatString:
  152. def __init__(
  153. self, sym_node_variable: SymNodeVariable, fmt_spec_var: VariableTracker
  154. ) -> None:
  155. from .constant import ConstantVariable
  156. self.sym_node_var = sym_node_variable
  157. self.fmt_var = ConstantVariable.create(
  158. "{:" + fmt_spec_var.as_python_constant() + "}"
  159. )
  160. def __repr__(self) -> str:
  161. return str.format(
  162. self.fmt_var.as_python_constant(),
  163. str(self.sym_node_var.evaluate_expr()),
  164. )
  165. def _create_realize_and_forward(
  166. name: str,
  167. ) -> Callable[[LazyVariableTracker, Any, Any], Any]:
  168. @functools.wraps(getattr(VariableTracker, name))
  169. def realize_and_forward(
  170. self: LazyVariableTracker, *args: Any, **kwargs: Any
  171. ) -> Any:
  172. return getattr(self.realize(), name)(*args, **kwargs)
  173. return realize_and_forward
  174. def _populate() -> None:
  175. for name, value in VariableTracker.__dict__.items():
  176. if name not in LazyVariableTracker.__dict__:
  177. if callable(value):
  178. setattr(LazyVariableTracker, name, _create_realize_and_forward(name))
  179. _populate()