_aoti.pyi 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. from ctypes import c_void_p
  2. from typing import overload, Protocol
  3. from torch import Tensor
  4. # Defined in torch/csrc/inductor/aoti_runner/pybind.cpp
  5. # Tensor to AtenTensorHandle
  6. def unsafe_alloc_void_ptrs_from_tensors(tensors: list[Tensor]) -> list[c_void_p]: ...
  7. def unsafe_alloc_void_ptr_from_tensor(tensor: Tensor) -> c_void_p: ...
  8. # AtenTensorHandle to Tensor
  9. def alloc_tensors_by_stealing_from_void_ptrs(
  10. handles: list[c_void_p],
  11. ) -> list[Tensor]: ...
  12. def alloc_tensor_by_stealing_from_void_ptr(
  13. handle: c_void_p,
  14. ) -> Tensor: ...
  15. class AOTIModelContainerRunner(Protocol):
  16. def run(
  17. self, inputs: list[Tensor], stream_handle: c_void_p = ...
  18. ) -> list[Tensor]: ...
  19. def get_call_spec(self) -> list[str]: ...
  20. def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
  21. def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
  22. def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
  23. def update_constant_buffer(
  24. self,
  25. tensor_map: dict[str, Tensor],
  26. use_inactive: bool,
  27. validate_full_updates: bool,
  28. user_managed: bool = ...,
  29. ) -> None: ...
  30. def swap_constant_buffer(self) -> None: ...
  31. def free_inactive_constant_buffer(self) -> None: ...
  32. class AOTIModelContainerRunnerCpu:
  33. def __init__(self, model_so_path: str, num_models: int) -> None: ...
  34. def run(
  35. self, inputs: list[Tensor], stream_handle: c_void_p = ...
  36. ) -> list[Tensor]: ...
  37. def get_call_spec(self) -> list[str]: ...
  38. def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
  39. def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
  40. def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
  41. def update_constant_buffer(
  42. self,
  43. tensor_map: dict[str, Tensor],
  44. use_inactive: bool,
  45. validate_full_updates: bool,
  46. user_managed: bool = ...,
  47. ) -> None: ...
  48. def swap_constant_buffer(self) -> None: ...
  49. def free_inactive_constant_buffer(self) -> None: ...
  50. class AOTIModelContainerRunnerCuda:
  51. @overload
  52. def __init__(self, model_so_path: str, num_models: int) -> None: ...
  53. @overload
  54. def __init__(
  55. self, model_so_path: str, num_models: int, device_str: str
  56. ) -> None: ...
  57. @overload
  58. def __init__(
  59. self, model_so_path: str, num_models: int, device_str: str, cubin_dir: str
  60. ) -> None: ...
  61. def run(
  62. self, inputs: list[Tensor], stream_handle: c_void_p = ...
  63. ) -> list[Tensor]: ...
  64. def get_call_spec(self) -> list[str]: ...
  65. def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
  66. def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
  67. def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
  68. def update_constant_buffer(
  69. self,
  70. tensor_map: dict[str, Tensor],
  71. use_inactive: bool,
  72. validate_full_updates: bool,
  73. user_managed: bool = ...,
  74. ) -> None: ...
  75. def swap_constant_buffer(self) -> None: ...
  76. def free_inactive_constant_buffer(self) -> None: ...
  77. class AOTIModelContainerRunnerXpu:
  78. @overload
  79. def __init__(self, model_so_path: str, num_models: int) -> None: ...
  80. @overload
  81. def __init__(
  82. self, model_so_path: str, num_models: int, device_str: str
  83. ) -> None: ...
  84. @overload
  85. def __init__(
  86. self, model_so_path: str, num_models: int, device_str: str, kernel_bin_dir: str
  87. ) -> None: ...
  88. def run(
  89. self, inputs: list[Tensor], stream_handle: c_void_p = ...
  90. ) -> list[Tensor]: ...
  91. def get_call_spec(self) -> list[str]: ...
  92. def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
  93. def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
  94. def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
  95. def update_constant_buffer(
  96. self,
  97. tensor_map: dict[str, Tensor],
  98. use_inactive: bool,
  99. validate_full_updates: bool,
  100. user_managed: bool = ...,
  101. ) -> None: ...
  102. def swap_constant_buffer(self) -> None: ...
  103. def free_inactive_constant_buffer(self) -> None: ...
  104. class AOTIModelContainerRunnerMps:
  105. def __init__(self, model_so_path: str, num_models: int) -> None: ...
  106. def run(
  107. self, inputs: list[Tensor], stream_handle: c_void_p = ...
  108. ) -> list[Tensor]: ...
  109. def get_call_spec(self) -> list[str]: ...
  110. def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
  111. def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
  112. def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
  113. def update_constant_buffer(
  114. self,
  115. tensor_map: dict[str, Tensor],
  116. use_inactive: bool,
  117. validate_full_updates: bool,
  118. user_managed: bool = ...,
  119. ) -> None: ...
  120. def swap_constant_buffer(self) -> None: ...
  121. def free_inactive_constant_buffer(self) -> None: ...
  122. # Defined in torch/csrc/inductor/aoti_package/pybind.cpp
  123. class AOTIModelPackageLoader:
  124. def __init__(
  125. self,
  126. model_package_path: str,
  127. model_name: str,
  128. run_single_threaded: bool,
  129. num_runners: int,
  130. device_index: int,
  131. ) -> None: ...
  132. def get_metadata(self) -> dict[str, str]: ...
  133. def run(
  134. self, inputs: list[Tensor], stream_handle: c_void_p = ...
  135. ) -> list[Tensor]: ...
  136. def boxed_run(
  137. self, inputs: list[Tensor], stream_handle: c_void_p = ...
  138. ) -> list[Tensor]: ...
  139. def get_call_spec(self) -> list[str]: ...
  140. def get_constant_fqns(self) -> list[str]: ...
  141. def load_constants(
  142. self,
  143. constants_map: dict[str, Tensor],
  144. use_inactive: bool,
  145. check_full_update: bool,
  146. user_managed: bool = ...,
  147. ) -> None: ...
  148. def update_constant_buffer(
  149. self,
  150. tensor_map: dict[str, Tensor],
  151. use_inactive: bool,
  152. validate_full_updates: bool,
  153. user_managed: bool = ...,
  154. ) -> None: ...