_script.pyi 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. # mypy: allow-untyped-defs
  2. # mypy: disable-error-code="type-arg"
  3. from collections.abc import Callable
  4. from typing import Any, NamedTuple, overload, TypeAlias, TypeVar
  5. from typing_extensions import Never
  6. from _typeshed import Incomplete
  7. import torch
  8. from torch._classes import classes as classes
  9. from torch._jit_internal import _qualified_name as _qualified_name
  10. from torch.jit._builtins import _register_builtin as _register_builtin
  11. from torch.jit._fuser import (
  12. _graph_for as _graph_for,
  13. _script_method_graph_for as _script_method_graph_for,
  14. )
  15. from torch.jit._monkeytype_config import (
  16. JitTypeTraceConfig as JitTypeTraceConfig,
  17. JitTypeTraceStore as JitTypeTraceStore,
  18. monkeytype_trace as monkeytype_trace,
  19. )
  20. from torch.jit._recursive import (
  21. _compile_and_register_class as _compile_and_register_class,
  22. infer_methods_to_compile as infer_methods_to_compile,
  23. ScriptMethodStub as ScriptMethodStub,
  24. wrap_cpp_module as wrap_cpp_module,
  25. )
  26. from torch.jit._serialization import validate_map_location as validate_map_location
  27. from torch.jit._state import (
  28. _enabled as _enabled,
  29. _set_jit_function_cache as _set_jit_function_cache,
  30. _set_jit_overload_cache as _set_jit_overload_cache,
  31. _try_get_jit_cached_function as _try_get_jit_cached_function,
  32. _try_get_jit_cached_overloads as _try_get_jit_cached_overloads,
  33. )
  34. from torch.jit.frontend import (
  35. get_default_args as get_default_args,
  36. get_jit_class_def as get_jit_class_def,
  37. get_jit_def as get_jit_def,
  38. )
  39. from torch.nn import Module as Module
  40. from torch.overrides import (
  41. has_torch_function as has_torch_function,
  42. has_torch_function_unary as has_torch_function_unary,
  43. has_torch_function_variadic as has_torch_function_variadic,
  44. )
  45. from torch.package import (
  46. PackageExporter as PackageExporter,
  47. PackageImporter as PackageImporter,
  48. )
  49. from torch.utils import set_module as set_module
  50. ScriptFunction = torch._C.ScriptFunction
  51. type_trace_db: JitTypeTraceStore
  52. # Defined in torch/csrc/jit/python/script_init.cpp
  53. ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]]
  54. _ClassVar = TypeVar("_ClassVar", bound=type)
  55. def _reduce(cls) -> None: ...
  56. class Attribute(NamedTuple):
  57. value: Incomplete
  58. type: Incomplete
  59. def _get_type_trace_db(): ...
  60. def _get_function_from_type(cls, name): ...
  61. def _is_new_style_class(cls): ...
  62. class OrderedDictWrapper:
  63. _c: Incomplete
  64. def __init__(self, _c) -> None: ...
  65. def keys(self): ...
  66. def values(self): ...
  67. def __len__(self) -> int: ...
  68. def __delitem__(self, k) -> None: ...
  69. def items(self): ...
  70. def __setitem__(self, k, v) -> None: ...
  71. def __contains__(self, k) -> bool: ...
  72. def __getitem__(self, k): ...
  73. class OrderedModuleDict(OrderedDictWrapper):
  74. _python_modules: Incomplete
  75. def __init__(self, module, python_dict) -> None: ...
  76. def items(self): ...
  77. def __contains__(self, k) -> bool: ...
  78. def __setitem__(self, k, v) -> None: ...
  79. def __getitem__(self, k): ...
  80. class ScriptMeta(type):
  81. def __init__(cls, name, bases, attrs) -> None: ...
  82. class _CachedForward:
  83. def __get__(self, obj, cls): ...
  84. class ScriptWarning(Warning): ...
  85. def script_method(fn): ...
  86. class ConstMap:
  87. const_mapping: Incomplete
  88. def __init__(self, const_mapping) -> None: ...
  89. def __getattr__(self, attr): ...
  90. def unpackage_script_module(
  91. importer: PackageImporter,
  92. script_module_id: str,
  93. ) -> torch.nn.Module: ...
  94. _magic_methods: Incomplete
  95. class RecursiveScriptClass:
  96. _c: Incomplete
  97. _props: Incomplete
  98. def __init__(self, cpp_class) -> None: ...
  99. def __getattr__(self, attr): ...
  100. def __setattr__(self, attr, value) -> None: ...
  101. def forward_magic_method(self, method_name, *args, **kwargs): ...
  102. def __getstate__(self) -> None: ...
  103. def __iadd__(self, other): ...
  104. def method_template(self, *args, **kwargs): ...
  105. class ScriptModule(Module, metaclass=ScriptMeta):
  106. __jit_unused_properties__: Incomplete
  107. def __init__(self) -> None: ...
  108. forward: Callable[..., Any]
  109. def __getattr__(self, attr): ...
  110. def __setattr__(self, attr, value) -> None: ...
  111. def define(self, src): ...
  112. def _replicate_for_data_parallel(self): ...
  113. def __reduce_package__(self, exporter: PackageExporter): ...
  114. # add __jit_unused_properties__
  115. @property
  116. def code(self) -> str: ...
  117. @property
  118. def code_with_constants(self) -> tuple[str, ConstMap]: ...
  119. @property
  120. def graph(self) -> torch.Graph: ...
  121. @property
  122. def inlined_graph(self) -> torch.Graph: ...
  123. @property
  124. def original_name(self) -> str: ...
  125. class RecursiveScriptModule(ScriptModule):
  126. _disable_script_meta: bool
  127. _c: Incomplete
  128. def __init__(self, cpp_module) -> None: ...
  129. @staticmethod
  130. def _construct(cpp_module, init_fn): ...
  131. @staticmethod
  132. def _finalize_scriptmodule(script_module) -> None: ...
  133. _concrete_type: Incomplete
  134. _modules: Incomplete
  135. _parameters: Incomplete
  136. _buffers: Incomplete
  137. __dict__: Incomplete
  138. def _reconstruct(self, cpp_module) -> None: ...
  139. def save(self, f, **kwargs): ...
  140. def _save_for_lite_interpreter(self, *args, **kwargs): ...
  141. def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs): ...
  142. def save_to_buffer(self, *args, **kwargs): ...
  143. def get_debug_state(self, *args, **kwargs): ...
  144. def extra_repr(self): ...
  145. def graph_for(self, *args, **kwargs): ...
  146. def define(self, src) -> None: ...
  147. def __getattr__(self, attr): ...
  148. def __setattr__(self, attr, value) -> None: ...
  149. def __copy__(self): ...
  150. def __deepcopy__(self, memo): ...
  151. def forward_magic_method(self, method_name, *args, **kwargs): ...
  152. def __iter__(self): ...
  153. def __getitem__(self, idx): ...
  154. def __len__(self) -> int: ...
  155. def __contains__(self, key) -> bool: ...
  156. def __dir__(self): ...
  157. def __bool__(self) -> bool: ...
  158. def _replicate_for_data_parallel(self): ...
  159. def _get_methods(cls): ...
  160. _compiled_methods_allowlist: Incomplete
  161. def _make_fail(name): ...
  162. def call_prepare_scriptable_func_impl(obj, memo): ...
  163. def call_prepare_scriptable_func(obj): ...
  164. def create_script_dict(obj): ...
  165. def create_script_list(obj, type_hint: Incomplete | None = ...): ...
  166. @overload
  167. def script(
  168. obj: type[Module],
  169. optimize: bool | None = None,
  170. _frames_up: int = 0,
  171. _rcb: ResolutionCallback | None = None,
  172. example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
  173. ) -> Never: ...
  174. @overload
  175. def script(
  176. obj: dict,
  177. optimize: bool | None = None,
  178. _frames_up: int = 0,
  179. _rcb: ResolutionCallback | None = None,
  180. example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
  181. ) -> torch.ScriptDict: ...
  182. @overload
  183. def script(
  184. obj: list,
  185. optimize: bool | None = None,
  186. _frames_up: int = 0,
  187. _rcb: ResolutionCallback | None = None,
  188. example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
  189. ) -> torch.ScriptList: ...
  190. @overload
  191. def script( # type: ignore[overload-overlap]
  192. obj: Module,
  193. optimize: bool | None = None,
  194. _frames_up: int = 0,
  195. _rcb: ResolutionCallback | None = None,
  196. example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
  197. ) -> RecursiveScriptModule: ...
  198. @overload
  199. def script( # type: ignore[overload-overlap]
  200. obj: _ClassVar,
  201. optimize: bool | None = None,
  202. _frames_up: int = 0,
  203. _rcb: ResolutionCallback | None = None,
  204. example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
  205. ) -> _ClassVar: ...
  206. @overload
  207. def script(
  208. obj: Callable,
  209. optimize: bool | None = None,
  210. _frames_up: int = 0,
  211. _rcb: ResolutionCallback | None = None,
  212. example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
  213. ) -> ScriptFunction: ...
  214. @overload
  215. def script(
  216. obj: Any,
  217. optimize: bool | None = None,
  218. _frames_up: int = 0,
  219. _rcb: ResolutionCallback | None = None,
  220. example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
  221. ) -> RecursiveScriptClass: ...
  222. @overload
  223. def script(
  224. obj,
  225. optimize: Incomplete | None = ...,
  226. _frames_up: int = ...,
  227. _rcb: Incomplete | None = ...,
  228. example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = ...,
  229. ): ...
  230. def _check_overload_defaults(impl_defaults, overload_defaults, loc) -> None: ...
  231. def _compile_function_with_overload(overload_fn, qual_name, impl_fn): ...
  232. def _get_overloads(obj): ...
  233. def _check_directly_compile_overloaded(obj) -> None: ...
  234. def interface(obj): ...
  235. def _recursive_compile_class(obj, loc): ...
  236. CompilationUnit: Incomplete
  237. def pad(s: str, padding: int, offset: int = ..., char: str = ...): ...
  238. class _ScriptProfileColumn:
  239. header: Incomplete
  240. alignment: Incomplete
  241. offset: Incomplete
  242. rows: Incomplete
  243. def __init__(
  244. self,
  245. header: str,
  246. alignment: int = ...,
  247. offset: int = ...,
  248. ) -> None: ...
  249. def add_row(self, lineno: int, value: Any): ...
  250. def materialize(self): ...
  251. class _ScriptProfileTable:
  252. cols: Incomplete
  253. source_range: Incomplete
  254. def __init__(
  255. self,
  256. cols: list[_ScriptProfileColumn],
  257. source_range: list[int],
  258. ) -> None: ...
  259. def dump_string(self): ...
  260. class _ScriptProfile:
  261. profile: Incomplete
  262. def __init__(self) -> None: ...
  263. def enable(self) -> None: ...
  264. def disable(self) -> None: ...
  265. def dump_string(self) -> str: ...
  266. def dump(self) -> None: ...
  267. def _unwrap_optional(x): ...