_autograd.pyi 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # mypy: allow-untyped-defs
  2. from collections.abc import Callable
  3. from enum import Enum
  4. from typing import Any
  5. import torch
  6. from torch._C._profiler import (
  7. _ProfilerEvent,
  8. ActiveProfilerType,
  9. ProfilerActivity,
  10. ProfilerConfig,
  11. )
  12. # Defined in torch/csrc/autograd/init.cpp
  13. class DeviceType(Enum):
  14. CPU = ...
  15. CUDA = ...
  16. XPU = ...
  17. MKLDNN = ...
  18. OPENGL = ...
  19. OPENCL = ...
  20. IDEEP = ...
  21. HIP = ...
  22. FPGA = ...
  23. MAIA = ...
  24. XLA = ...
  25. MTIA = ...
  26. MPS = ...
  27. HPU = ...
  28. Meta = ...
  29. Vulkan = ...
  30. Metal = ...
  31. PrivateUse1 = ...
  32. class ProfilerEvent:
  33. def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
  34. def cpu_memory_usage(self) -> int: ...
  35. def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
  36. def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ...
  37. def cuda_memory_usage(self) -> int: ...
  38. def device(self) -> int: ...
  39. def handle(self) -> int: ...
  40. def has_cuda(self) -> bool: ...
  41. def is_remote(self) -> bool: ...
  42. def kind(self) -> int: ...
  43. def name(self) -> str: ...
  44. def node_id(self) -> int: ...
  45. def sequence_nr(self) -> int: ...
  46. def shapes(self) -> list[list[int]]: ...
  47. def thread_id(self) -> int: ...
  48. def flops(self) -> float: ...
  49. def is_async(self) -> bool: ...
  50. class _KinetoEvent:
  51. def name(self) -> str: ...
  52. def overload_name(self) -> str: ...
  53. def device_index(self) -> int: ...
  54. def device_resource_id(self) -> int: ...
  55. def start_ns(self) -> int: ...
  56. def end_ns(self) -> int: ...
  57. def duration_ns(self) -> int: ...
  58. def is_async(self) -> bool: ...
  59. def linked_correlation_id(self) -> int: ...
  60. def shapes(self) -> list[list[int]]: ...
  61. def dtypes(self) -> list[str]: ...
  62. def concrete_inputs(self) -> list[Any]: ...
  63. def kwinputs(self) -> dict[str, Any]: ...
  64. def device_type(self) -> DeviceType: ...
  65. def start_thread_id(self) -> int: ...
  66. def end_thread_id(self) -> int: ...
  67. def correlation_id(self) -> int: ...
  68. def fwd_thread_id(self) -> int: ...
  69. def stack(self) -> list[str]: ...
  70. def scope(self) -> int: ...
  71. def sequence_nr(self) -> int: ...
  72. def flops(self) -> int: ...
  73. def cuda_elapsed_us(self) -> int: ...
  74. def privateuse1_elapsed_us(self) -> int: ...
  75. def is_user_annotation(self) -> bool: ...
  76. def is_hidden_event(self) -> bool: ...
  77. def metadata_json(self) -> str: ...
  78. class _ProfilerResult:
  79. def events(self) -> list[_KinetoEvent]: ...
  80. def legacy_events(self) -> list[list[ProfilerEvent]]: ...
  81. def save(self, path: str) -> None: ...
  82. def experimental_event_tree(self) -> list[_ProfilerEvent]: ...
  83. def trace_start_ns(self) -> int: ...
  84. class SavedTensor: ...
  85. def _enable_profiler(
  86. config: ProfilerConfig,
  87. activities: set[ProfilerActivity],
  88. ) -> None: ...
  89. def _prepare_profiler(
  90. config: ProfilerConfig,
  91. activities: set[ProfilerActivity],
  92. ) -> None: ...
  93. def _toggle_collection_dynamic(
  94. enable: bool,
  95. activities: set[ProfilerActivity],
  96. ) -> None: ...
  97. def _disable_profiler() -> _ProfilerResult: ...
  98. def _profiler_enabled() -> bool: ...
  99. def _add_metadata_json(key: str, value: str) -> None: ...
  100. def _kineto_step() -> None: ...
  101. def _get_current_graph_task_keep_graph() -> bool: ...
  102. def _get_sequence_nr() -> int: ...
  103. def kineto_available() -> bool: ...
  104. def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ...
  105. def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
  106. def _supported_activities() -> set[ProfilerActivity]: ...
  107. def _enable_record_function(enable: bool) -> None: ...
  108. def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
  109. def _push_saved_tensors_default_hooks(
  110. pack_hook: Callable[[torch.Tensor], Any],
  111. unpack_hook: Callable[[Any], torch.Tensor],
  112. ) -> None: ...
  113. def _pop_saved_tensors_default_hooks() -> None: ...
  114. def _top_saved_tensors_default_hooks(
  115. ignore_is_tracing: bool,
  116. ) -> tuple[Callable[[torch.Tensor], Any], Callable[[Any], torch.Tensor]]: ...
  117. def _unsafe_set_version_counter(
  118. t: tuple[torch.Tensor, ...], prev_version: tuple[int, ...]
  119. ) -> None: ...
  120. def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
  121. def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
  122. def _profiler_type() -> ActiveProfilerType: ...
  123. def _saved_tensors_hooks_enable() -> None: ...
  124. def _saved_tensors_hooks_disable(message: str, fail_if_non_empty=True) -> None: ...
  125. def _saved_tensors_hooks_get_disabled_error_message() -> str | None: ...
  126. def _saved_tensors_hooks_set_tracing(is_tracing: bool) -> bool: ...
  127. class CreationMeta(Enum):
  128. DEFAULT = ...
  129. IN_CUSTOM_FUNCTION = ...
  130. MULTI_OUTPUT_NODE = ...
  131. NO_GRAD_MODE = ...
  132. INFERENCE_MODE = ...
  133. def _set_creation_meta(t: torch.Tensor, creation_meta: CreationMeta) -> None: ...
  134. def _get_creation_meta(t: torch.Tensor) -> CreationMeta: ...