_script.pyi 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. # mypy: allow-untyped-defs
  2. # mypy: disable-error-code="type-arg"
  3. from typing import Any, Callable, NamedTuple, overload, TypeVar
  4. from typing_extensions import Never, TypeAlias
  5. from _typeshed import Incomplete
  6. import torch
  7. from torch._classes import classes as classes
  8. from torch._jit_internal import _qualified_name as _qualified_name
  9. from torch.jit._builtins import _register_builtin as _register_builtin
  10. from torch.jit._fuser import (
  11. _graph_for as _graph_for,
  12. _script_method_graph_for as _script_method_graph_for,
  13. )
  14. from torch.jit._monkeytype_config import (
  15. JitTypeTraceConfig as JitTypeTraceConfig,
  16. JitTypeTraceStore as JitTypeTraceStore,
  17. monkeytype_trace as monkeytype_trace,
  18. )
  19. from torch.jit._recursive import (
  20. _compile_and_register_class as _compile_and_register_class,
  21. infer_methods_to_compile as infer_methods_to_compile,
  22. ScriptMethodStub as ScriptMethodStub,
  23. wrap_cpp_module as wrap_cpp_module,
  24. )
  25. from torch.jit._serialization import validate_map_location as validate_map_location
  26. from torch.jit._state import (
  27. _enabled as _enabled,
  28. _set_jit_function_cache as _set_jit_function_cache,
  29. _set_jit_overload_cache as _set_jit_overload_cache,
  30. _try_get_jit_cached_function as _try_get_jit_cached_function,
  31. _try_get_jit_cached_overloads as _try_get_jit_cached_overloads,
  32. )
  33. from torch.jit.frontend import (
  34. get_default_args as get_default_args,
  35. get_jit_class_def as get_jit_class_def,
  36. get_jit_def as get_jit_def,
  37. )
  38. from torch.nn import Module as Module
  39. from torch.overrides import (
  40. has_torch_function as has_torch_function,
  41. has_torch_function_unary as has_torch_function_unary,
  42. has_torch_function_variadic as has_torch_function_variadic,
  43. )
  44. from torch.package import (
  45. PackageExporter as PackageExporter,
  46. PackageImporter as PackageImporter,
  47. )
  48. from torch.utils import set_module as set_module
  49. ScriptFunction = torch._C.ScriptFunction
  50. type_trace_db: JitTypeTraceStore
  51. # Defined in torch/csrc/jit/python/script_init.cpp
  52. ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]]
  53. _ClassVar = TypeVar("_ClassVar", bound=type)
# Module-level helper taking `cls`; annotated as returning None.
# NOTE(review): presumably installed as ScriptFunction.__reduce__ to refuse
# pickling (i.e. it raises instead of returning). If so, `NoReturn` would be a
# more precise annotation — confirm against torch/jit/_script.py.
def _reduce(cls) -> None: ...
# NamedTuple with two fields, `value` and `type`, both left untyped
# (`Incomplete`) in this stub.
# NOTE(review): presumably pairs a module attribute's value with its
# TorchScript type, as in `torch.jit.Attribute(value, type)` — confirm.
class Attribute(NamedTuple):
    value: Incomplete
    type: Incomplete
# Untyped accessor with no parameters.
# NOTE(review): by its name, presumably returns the module-level
# `type_trace_db` — confirm before tightening the return type.
def _get_type_trace_db(): ...
# Untyped helper taking a class and an attribute name.
# NOTE(review): presumably looks `name` up on `cls` and returns it only when it
# is a function — confirm.
def _get_function_from_type(cls, name): ...
# Untyped predicate-style helper taking a class.
# NOTE(review): return type is unannotated here; presumably a bool — confirm.
def _is_new_style_class(cls): ...
# Mapping-style wrapper whose only declared state is `_c`.
# Declares keys/values/items plus len(), `in`, and item get/set/delete.
# NOTE(review): `_c` is presumably the object passed to __init__ (an
# underlying C++-bound container) — confirm against torch/jit/_script.py.
class OrderedDictWrapper:
    _c: Incomplete
    def __init__(self, _c) -> None: ...
    def keys(self): ...
    def values(self): ...
    def __len__(self) -> int: ...
    def __delitem__(self, k) -> None: ...
    def items(self): ...
    def __setitem__(self, k, v) -> None: ...
    def __contains__(self, k) -> bool: ...
    def __getitem__(self, k): ...
# OrderedDictWrapper subclass constructed from a `module` and a `python_dict`.
# It declares an extra `_python_modules` attribute and overrides items, `in`,
# and item get/set. keys/values/len/delete are inherited unchanged.
class OrderedModuleDict(OrderedDictWrapper):
    _python_modules: Incomplete
    def __init__(self, module, python_dict) -> None: ...
    def items(self): ...
    def __contains__(self, k) -> bool: ...
    def __setitem__(self, k, v) -> None: ...
    def __getitem__(self, k): ...
# Metaclass (a `type` subclass) used by ScriptModule. Its __init__ receives the
# standard class-creation arguments: class name, bases, and attribute dict.
class ScriptMeta(type):
    def __init__(cls, name, bases, attrs) -> None: ...
# Descriptor: defines only __get__(obj, cls).
# NOTE(review): by name, presumably supplies a cached `forward` lookup on
# script modules — confirm.
class _CachedForward:
    def __get__(self, obj, cls): ...
# TorchScript-specific warning category: a bare `Warning` subclass with no
# extra members, so it can be targeted by `warnings` filters.
class ScriptWarning(Warning): ...
# Untyped helper taking a function `fn`.
# NOTE(review): presumably a decorator marking a ScriptModule method for
# compilation — confirm.
def script_method(fn): ...
# Wraps a `const_mapping` and declares dynamic attribute access.
# Appears as the second element of `ScriptModule.code_with_constants`.
# NOTE(review): __getattr__ presumably returns `const_mapping[attr]` — confirm.
class ConstMap:
    const_mapping: Incomplete
    def __init__(self, const_mapping) -> None: ...
    def __getattr__(self, attr): ...
# Takes a torch.package PackageImporter and a string id, and returns a
# torch.nn.Module.
# NOTE(review): presumably the load-side counterpart of
# ScriptModule.__reduce_package__ — confirm.
def unpackage_script_module(
    importer: PackageImporter,
    script_module_id: str,
) -> torch.nn.Module: ...
# Module-level value, typed `Incomplete` in this stub.
# NOTE(review): presumably the dunder-method names that RecursiveScriptClass
# forwards via `forward_magic_method` — confirm.
_magic_methods: Incomplete
# Python-side class constructed from `cpp_class`, declaring `_c` and `_props`
# attributes (both typed `Incomplete`).
# NOTE(review): presumably wraps a TorchScript class instance and routes
# attribute access and magic methods to `_c` — confirm in _script.py.
class RecursiveScriptClass:
    _c: Incomplete
    _props: Incomplete
    def __init__(self, cpp_class) -> None: ...
    def __getattr__(self, attr): ...
    def __setattr__(self, attr, value) -> None: ...
    # NOTE(review): presumably dispatches the magic method named `method_name`
    # with *args/**kwargs — confirm.
    def forward_magic_method(self, method_name, *args, **kwargs): ...
    # Annotated `-> None`. NOTE(review): presumably raises to block pickling.
    def __getstate__(self) -> None: ...
    def __iadd__(self, other): ...
# Free function whose first parameter is named `self`.
# NOTE(review): presumably the template attached to RecursiveScriptClass for
# each name in `_magic_methods`, forwarding to `forward_magic_method`. The
# scraped source lost its indentation, so confirm it is module-level.
def method_template(self, *args, **kwargs): ...
# Base class for TorchScript modules: an nn.Module subclass whose metaclass is
# ScriptMeta.
class ScriptModule(Module, metaclass=ScriptMeta):
    # NOTE(review): presumably lists the properties below so the TorchScript
    # compiler skips them — confirm.
    __jit_unused_properties__: Incomplete
    def __init__(self) -> None: ...
    # Declared as an attribute holding any callable, not as a method.
    forward: Callable[..., Any]
    def __getattr__(self, attr): ...
    def __setattr__(self, attr, value) -> None: ...
    # NOTE(review): by name, compiles TorchScript source `src` onto this
    # module — confirm.
    def define(self, src): ...
    def _replicate_for_data_parallel(self): ...
    # Hook receiving a PackageExporter. NOTE(review): presumably invoked by
    # torch.package when exporting this module — confirm.
    def __reduce_package__(self, exporter: PackageExporter): ...
    # add __jit_unused_properties__
    # Read-only properties exposing the compiled code (str) and graphs.
    @property
    def code(self) -> str: ...
    @property
    def code_with_constants(self) -> tuple[str, ConstMap]: ...
    @property
    def graph(self) -> torch.Graph: ...
    @property
    def inlined_graph(self) -> torch.Graph: ...
    @property
    def original_name(self) -> str: ...
# ScriptModule subclass constructed from `cpp_module`, with the underlying
# object declared as `_c`. The `script` overload for nn.Module instances is
# annotated to return this class.
class RecursiveScriptModule(ScriptModule):
    _disable_script_meta: bool
    _c: Incomplete
    def __init__(self, cpp_module) -> None: ...
    # Static factory taking a `cpp_module` and an `init_fn`.
    @staticmethod
    def _construct(cpp_module, init_fn): ...
    @staticmethod
    def _finalize_scriptmodule(script_module) -> None: ...
    # Instance state declared with types left `Incomplete` in this stub.
    _concrete_type: Incomplete
    _modules: Incomplete
    _parameters: Incomplete
    _buffers: Incomplete
    __dict__: Incomplete
    def _reconstruct(self, cpp_module) -> None: ...
    # Save-related entry points; arguments are untyped in this stub.
    def save(self, f, **kwargs): ...
    def _save_for_lite_interpreter(self, *args, **kwargs): ...
    def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs): ...
    def save_to_buffer(self, *args, **kwargs): ...
    def get_debug_state(self, *args, **kwargs): ...
    def extra_repr(self): ...
    def graph_for(self, *args, **kwargs): ...
    # Overrides ScriptModule.define, here annotated to return None.
    def define(self, src) -> None: ...
    def __getattr__(self, attr): ...
    def __setattr__(self, attr, value) -> None: ...
    # copy.copy / copy.deepcopy protocol hooks.
    def __copy__(self): ...
    def __deepcopy__(self, memo): ...
    def forward_magic_method(self, method_name, *args, **kwargs): ...
    # Container-protocol dunders declared on the module.
    def __iter__(self): ...
    def __getitem__(self, idx): ...
    def __len__(self) -> int: ...
    def __contains__(self, key) -> bool: ...
    def __dir__(self): ...
    def __bool__(self) -> bool: ...
    def _replicate_for_data_parallel(self): ...
# Untyped helper taking a class.
# NOTE(review): presumably returns the functions/methods defined on `cls`
# — confirm.
def _get_methods(cls): ...
# Module-level value, typed `Incomplete` in this stub.
# NOTE(review): presumably the nn.Module method names that remain usable on
# ScriptModules — confirm.
_compiled_methods_allowlist: Incomplete
# Untyped helper taking a method `name`.
# NOTE(review): presumably returns a stand-in method that raises because
# `name` is unsupported on ScriptModules — confirm.
def _make_fail(name): ...
# Untyped helper taking an object and a `memo`.
# NOTE(review): presumably applies `__prepare_scriptable__` hooks recursively,
# with `memo` tracking already-visited objects — confirm.
def call_prepare_scriptable_func_impl(obj, memo): ...
# Untyped single-argument entry point.
# NOTE(review): presumably the public wrapper around
# `call_prepare_scriptable_func_impl` that supplies a fresh memo — confirm.
def call_prepare_scriptable_func(obj): ...
# Untyped helper taking `obj`.
# NOTE(review): presumably builds the torch.ScriptDict that `script` is typed
# to return for dict inputs — confirm.
def create_script_dict(obj): ...
# Untyped helper taking `obj` and an optional `type_hint`.
# NOTE(review): presumably builds the torch.ScriptList that `script` is typed
# to return for list inputs — confirm.
def create_script_list(obj, type_hint: Incomplete | None = ...): ...
# `script` overloads. Type checkers try them top to bottom and pick the first
# match, so the order below encodes dispatch precedence.
#
# 1. An nn.Module *class* (not an instance) is annotated to return `Never`:
#    type checkers treat such a call as not returning normally.
@overload
def script(
    obj: type[Module],
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> Never: ...
# 2. A dict is typed to become a torch.ScriptDict.
@overload
def script(
    obj: dict,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> torch.ScriptDict: ...
# 3. A list is typed to become a torch.ScriptList.
@overload
def script(
    obj: list,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> torch.ScriptList: ...
# 4. An nn.Module instance is typed to become a RecursiveScriptModule.
#    NOTE(review): presumably placed before the Callable overload because
#    Module instances are callable, hence the overlap ignore — confirm.
@overload
def script(  # type: ignore[overload-overlap]
    obj: Module,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> RecursiveScriptModule: ...
# 5. Any other class is returned as that same class type (`_ClassVar` is a
#    TypeVar bound to `type`).
@overload
def script(  # type: ignore[overload-overlap]
    obj: _ClassVar,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> _ClassVar: ...
# 6. Remaining callables (e.g. plain functions) are typed as ScriptFunction.
@overload
def script(
    obj: Callable,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> ScriptFunction: ...
# 7. Any other object is typed as a RecursiveScriptClass.
@overload
def script(
    obj: Any,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> RecursiveScriptClass: ...
# 8. Untyped catch-all. `optimize` and `_rcb` are `Incomplete` and the return is
#    unannotated, so this matches calls the stricter overloads reject (e.g. a
#    non-bool `optimize`). Defaults are elided with `...`.
@overload
def script(
    obj,
    optimize: Incomplete | None = ...,
    _frames_up: int = ...,
    _rcb: Incomplete | None = ...,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = ...,
): ...
# Takes implementation defaults, overload defaults, and a `loc`; returns None.
# NOTE(review): presumably raises when an overload's defaults disagree with
# the implementation's, reporting at `loc` — confirm.
def _check_overload_defaults(impl_defaults, overload_defaults, loc) -> None: ...
# Untyped helper taking an overload function, its qualified name, and the
# implementation function.
# NOTE(review): presumably compiles one overload against `impl_fn` — confirm.
def _compile_function_with_overload(overload_fn, qual_name, impl_fn): ...
# Untyped helper taking `obj`.
# NOTE(review): presumably returns the compiled overloads registered for
# `obj`, possibly using the jit overload cache — confirm.
def _get_overloads(obj): ...
# Takes `obj`; returns None.
# NOTE(review): presumably raises if an overloaded function is compiled
# directly — confirm.
def _check_directly_compile_overloaded(obj) -> None: ...
# Untyped single-argument helper.
# NOTE(review): presumably the `torch.jit.interface` class decorator — confirm.
def interface(obj): ...
# Untyped helper taking a class object and a `loc`.
# NOTE(review): presumably compiles `obj` as a TorchScript class, with `loc`
# used for error reporting — confirm.
def _recursive_compile_class(obj, loc): ...
# Module-level name typed `Incomplete` in this stub.
# NOTE(review): presumably re-exports torch._C.CompilationUnit — confirm.
CompilationUnit: Incomplete
# Takes a string, an integer padding, and an optional integer offset and fill
# `char`; the return type is unannotated.
# NOTE(review): presumably returns `s` padded for profile-table layout — confirm.
def pad(s: str, padding: int, offset: int = ..., char: str = ...): ...
# One column of a script-profile report. It is constructed with a `header`
# string and optional integer `alignment` and `offset`. `add_row` takes a line
# number and a value.
# NOTE(review): `rows` presumably accumulates those values and `materialize`
# renders them — confirm.
class _ScriptProfileColumn:
    header: Incomplete
    alignment: Incomplete
    offset: Incomplete
    rows: Incomplete
    def __init__(
        self,
        header: str,
        alignment: int = ...,
        offset: int = ...,
    ) -> None: ...
    def add_row(self, lineno: int, value: Any): ...
    def materialize(self): ...
# Profile table built from a list of _ScriptProfileColumn and a
# `source_range` list of ints.
class _ScriptProfileTable:
    cols: Incomplete
    source_range: Incomplete
    def __init__(
        self,
        cols: list[_ScriptProfileColumn],
        source_range: list[int],
    ) -> None: ...
    # NOTE(review): unannotated here, while _ScriptProfile.dump_string is
    # annotated `-> str`; consider aligning once confirmed.
    def dump_string(self): ...
# Profiler handle with enable/disable switches, a string report
# (`dump_string -> str`), and `dump() -> None`.
# NOTE(review): `dump` presumably prints the report from `dump_string` — confirm.
class _ScriptProfile:
    profile: Incomplete
    def __init__(self) -> None: ...
    def enable(self) -> None: ...
    def disable(self) -> None: ...
    def dump_string(self) -> str: ...
    def dump(self) -> None: ...
# Untyped single-argument helper.
# NOTE(review): presumably returns `x` after asserting it is not None (an
# Optional-refinement helper for TorchScript) — confirm.
def _unwrap_optional(x): ...