__init__.py 104 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980
  1. """
  2. The torch package contains data structures for multi-dimensional
  3. tensors and defines mathematical operations over these tensors.
  4. Additionally, it provides many utilities for efficient serialization of
  5. Tensors and arbitrary types, and other useful utilities.
  6. It has a CUDA counterpart, that enables you to run your tensor computations
  7. on an NVIDIA GPU with compute capability >= 3.0.
  8. """
  9. # mypy: allow-untyped-defs
  10. import builtins
  11. import ctypes
  12. import functools
  13. import glob
  14. import importlib
  15. import inspect
  16. import math
  17. import os
  18. import platform
  19. import sys
  20. import textwrap
  21. import threading
  22. import warnings
  23. from collections.abc import Callable as _Callable
  24. from typing import (
  25. Any as _Any,
  26. get_origin as _get_origin,
  27. overload as _overload,
  28. TYPE_CHECKING,
  29. TypeVar as _TypeVar,
  30. )
  31. from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs
  # A number of torch.package users still perform this check internally, so
  # the function has to stay. TODO: remove the tests that rely on this check,
  # as they are likely stale.
  def _running_with_deploy() -> builtins.bool:
      """Return ``False`` unconditionally; retained only for legacy callers."""
      return False
  37. from torch._utils import (
  38. _functionalize_sync as _sync,
  39. _import_dotted_name,
  40. classproperty,
  41. )
  42. from torch._utils_internal import (
  43. get_file_path,
  44. prepare_multiprocessing_environment,
  45. profiler_allow_cudagraph_cupti_lazy_reinit_cuda12,
  46. USE_GLOBAL_DEPS,
  47. USE_RTLD_GLOBAL_WITH_LIBTORCH,
  48. )
  49. from torch.torch_version import __version__ as __version__
  50. if TYPE_CHECKING:
  51. from torch.types import Device, IntLikeType
  # Names exported by ``from torch import *``.
  __all__ = [
      "BoolStorage",
      "BoolTensor",
      "ByteStorage",
      "ByteTensor",
      "CharStorage",
      "CharTensor",
      "DoubleStorage",
      "DoubleTensor",
      "FloatStorage",
      "FloatTensor",
      "GradScaler",
      "IntStorage",
      "IntTensor",
      "LongStorage",
      "LongTensor",
      "ShortStorage",
      "ShortTensor",
      "SymBool",
      "SymFloat",
      "SymInt",
      "Tensor",
      "TypedStorage",
      "UntypedStorage",
      "are_deterministic_algorithms_enabled",
      "autocast",
      "chunk",
      "compile",
      "cond",
      "enable_grad",
      "export",
      "get_default_device",
      "get_deterministic_debug_mode",
      "get_device_module",
      "get_float32_matmul_precision",
      "get_rng_state",
      "inference_mode",
      "initial_seed",
      "is_deterministic_algorithms_warn_only_enabled",
      "is_storage",
      "is_tensor",
      "is_warn_always_enabled",
      "load",
      "lobpcg",
      "manual_seed",
      "matmul",
      "no_grad",
      "rand",
      "randn",
      "save",
      "seed",
      "set_default_device",
      "set_default_tensor_type",
      "set_deterministic_debug_mode",
      "set_float32_matmul_precision",
      "set_printoptions",
      "set_rng_state",
      "set_warn_always",
      "split",
      "stack",
      "sym_float",
      "sym_fresh_size",
      "sym_int",
      "sym_ite",
      "sym_max",
      "sym_min",
      "sym_not",
      "sym_sum",
      "typename",
      "unravel_index",
      "use_deterministic_algorithms",
      "vmap",
  ]

  # Please keep this list sorted (plain string order: uppercase names sort
  # before lowercase ones). The assert fails the import if an entry is added
  # out of order.
  assert __all__ == sorted(__all__)
  ################################################################################
  # Load the extension module
  ################################################################################

  # If PyTorch was built against the ROCm runtime wheels, then there will be
  # a _rocm_init module and it will define an initialize() function which can
  # prepare ROCm for use. See general documentation on ROCm runtime wheels:
  #   https://github.com/ROCm/TheRock/blob/main/docs/packaging/python_packaging.md
  # Since this module is only ever added to the wheel if built for such a
  # deployment, it is always safe to attempt.
  try:
      from . import _rocm_init  # type: ignore[attr-defined]
  except ImportError:
      # Module is absent from this build: nothing to initialize.
      pass
  else:
      _rocm_init.initialize()
      # Drop the helper so it does not linger in the package namespace.
      del _rocm_init
  if sys.platform == "win32":

      def _load_dll_libraries() -> None:
          """Register DLL search directories and eagerly load the bundled DLLs.

          Steps, in order:

          1. Collect the existing candidate directories (this package's
             ``lib`` folder, the interpreter's ``Library/bin`` and ``bin``
             folders, the base interpreter's ``Library/bin`` when running in
             a virtualenv, and the user-base ``Library/bin``).
          2. Append the NvToolsExt and CUDA toolkit ``bin`` directories when
             the corresponding DLLs are not already present in step 1.
          3. Register every directory with ``os.add_dll_directory``.
          4. Probe for the MSVC runtime and print a hint if it is missing.
          5. Load every ``*.dll`` in this package's ``lib`` folder, falling
             back to a ``PATH``-based ``LoadLibraryW`` when the flag-based
             ``LoadLibraryExW`` reports "module not found".

          Raises:
              OSError: built by ``ctypes.WinError`` when a DLL cannot be
                  loaded; the message names the offending DLL.
          """
          import sysconfig

          from torch.version import cuda as cuda_version

          pfiles_path = os.getenv("ProgramFiles", r"C:\Program Files")
          py_dll_path = os.path.join(sys.exec_prefix, "Library", "bin")
          th_dll_path = os.path.join(os.path.dirname(__file__), "lib")
          usebase_path = os.path.join(
              sysconfig.get_config_var("userbase"), "Library", "bin"
          )
          py_root_bin_path = os.path.join(sys.exec_prefix, "bin")

          # When users create a virtualenv that inherits the base environment,
          # we will need to add the corresponding library directory into
          # DLL search directories. Otherwise, it will rely on `PATH` which
          # is dependent on user settings.
          if sys.exec_prefix != sys.base_exec_prefix:
              base_py_dll_path = os.path.join(sys.base_exec_prefix, "Library", "bin")
          else:
              base_py_dll_path = ""

          # Keep only directories that exist; the empty-string placeholder
          # above is filtered out here as well.
          dll_paths = [
              p
              for p in (
                  th_dll_path,
                  py_dll_path,
                  base_py_dll_path,
                  usebase_path,
                  py_root_bin_path,
              )
              if os.path.exists(p)
          ]

          # Fall back to the NvToolsExt install location only when the DLL is
          # not already shipped in one of the directories found so far.
          if not builtins.any(
              os.path.exists(os.path.join(p, "nvToolsExt64_1.dll")) for p in dll_paths
          ):
              nvtoolsext_dll_path = os.path.join(
                  os.getenv(
                      "NVTOOLSEXT_PATH",
                      os.path.join(pfiles_path, "NVIDIA Corporation", "NvToolsExt"),
                  ),
                  "bin",
                  "x64",
              )
          else:
              nvtoolsext_dll_path = ""

          # Same idea for the CUDA runtime: only consult the toolkit install
          # when this is a CUDA build and no cudart DLL was found yet.
          if cuda_version and builtins.all(
              not glob.glob(os.path.join(p, "cudart64*.dll")) for p in dll_paths
          ):
              cuda_version_1 = cuda_version.replace(".", "_")
              cuda_path_var = "CUDA_PATH_V" + cuda_version_1
              default_path = os.path.join(
                  pfiles_path, "NVIDIA GPU Computing Toolkit", "CUDA", f"v{cuda_version}"
              )
              cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), "bin")
          else:
              cuda_path = ""

          dll_paths.extend(
              p for p in (nvtoolsext_dll_path, cuda_path) if os.path.exists(p)
          )

          kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
          with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
          # 0x0001 is SEM_FAILCRITICALERRORS. The previous mode is restored in
          # the ``finally`` below, including when a load failure is raised.
          prev_error_mode = kernel32.SetErrorMode(0x0001)
          try:
              # Handles are pointer-sized; without this ctypes would truncate
              # them to a C int.
              kernel32.LoadLibraryW.restype = ctypes.c_void_p
              if with_load_library_flags:
                  kernel32.LoadLibraryExW.restype = ctypes.c_void_p

              for dll_path in dll_paths:
                  os.add_dll_directory(dll_path)

              try:
                  ctypes.CDLL("vcruntime140.dll")
                  ctypes.CDLL("msvcp140.dll")
                  if platform.machine() != "ARM64":
                      ctypes.CDLL("vcruntime140_1.dll")
              except OSError:
                  # Best-effort hint only; the real failure (if any) surfaces
                  # when the torch DLLs are loaded below.
                  print(
                      textwrap.dedent(
                          """
                          Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure.
                          It can be downloaded at https://aka.ms/vs/17/release/vc_redist.x64.exe
                          """
                      ).strip()
                  )

              dlls = glob.glob(os.path.join(th_dll_path, "*.dll"))
              path_patched = False
              for dll in dlls:
                  is_loaded = False
                  if with_load_library_flags:
                      # 0x00001100 = LOAD_LIBRARY_SEARCH_DEFAULT_DIRS |
                      # LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR
                      res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
                      last_error = ctypes.get_last_error()
                      # 126 is ERROR_MOD_NOT_FOUND: tolerated here because the
                      # PATH-based fallback below gets a second chance.
                      if res is None and last_error != 126:
                          err = ctypes.WinError(last_error)
                          err.strerror += (
                              f' Error loading "{dll}" or one of its dependencies.'
                          )
                          raise err
                      elif res is not None:
                          is_loaded = True
                  if not is_loaded:
                      # Prepend the search directories to PATH once, lazily,
                      # the first time the fallback loader is needed.
                      if not path_patched:
                          os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]])
                          path_patched = True
                      res = kernel32.LoadLibraryW(dll)
                      if res is None:
                          err = ctypes.WinError(ctypes.get_last_error())
                          err.strerror += (
                              f' Error loading "{dll}" or one of its dependencies.'
                          )
                          raise err
          finally:
              kernel32.SetErrorMode(prev_error_mode)

      _load_dll_libraries()
      del _load_dll_libraries
  def _get_cuda_dep_paths(path: str, lib_folder: str, lib_name: str) -> list[str]:
      """Return files under ``path`` matching a CUDA library glob pattern.

      Libraries can either be in
        - ``path/nvidia/lib_folder/lib`` or
        - ``path/nvidia/cuXX/lib`` (since CUDA 13.0) or
        - ``path/lib_folder/lib``

      Args:
          path: Root directory to search (callers pass ``sys.path`` entries).
          lib_folder: Per-library folder name, e.g. ``"cublas"``.
          lib_name: Glob pattern for the library file name.

      Returns:
          Matches from the ``nvidia/`` layouts first, followed by matches
          from ``path/lib_folder/lib``. Empty when nothing matches.
      """
      from torch.version import cuda as cuda_version

      nvidia_lib_paths = glob.glob(
          os.path.join(path, "nvidia", lib_folder, "lib", lib_name)
      )
      if cuda_version is not None:
          # The cuXX folder is keyed on the major CUDA version only.
          maj_cuda_version = cuda_version.split(".")[0]
          nvidia_lib_paths += glob.glob(
              os.path.join(path, "nvidia", f"cu{maj_cuda_version}", "lib", lib_name)
          )
      lib_paths = glob.glob(os.path.join(path, lib_folder, "lib", lib_name))

      return nvidia_lib_paths + lib_paths
  def _preload_cuda_lib(lib_folder: str, lib_name: str, required: bool = True) -> None:  # type: ignore[valid-type]
      """Preloads cuda library if it could not be found otherwise.

      Walks ``sys.path`` in order and loads the first file that
      ``_get_cuda_dep_paths`` reports for ``lib_folder`` / ``lib_name``.

      Args:
          lib_folder: Per-library folder name, e.g. ``"cublas"``.
          lib_name: Glob pattern for the library file name.
          required: When true, a missing library is an error; when false it
              is silently skipped.

      Raises:
          ValueError: if ``required`` is true and no match is found.
          AssertionError: if called on a platform other than Linux.
      """
      # Should only be called on Linux if default path resolution have failed
      assert platform.system() == "Linux", "Should only be called on Linux"

      lib_path = None
      for path in sys.path:
          candidate_lib_paths = _get_cuda_dep_paths(path, lib_folder, lib_name)
          if candidate_lib_paths:
              # First hit wins: earlier sys.path entries take precedence.
              lib_path = candidate_lib_paths[0]
              break
      if not lib_path and required:
          raise ValueError(f"{lib_name} not found in the system path {sys.path}")
      if lib_path:
          ctypes.CDLL(lib_path)
  281. def _preload_cuda_deps(err: OSError | None = None) -> None:
  282. cuda_libs: list[tuple[str, str]] = [
  283. ("cublas", "libcublas.so.*[0-9]"),
  284. ("cudnn", "libcudnn.so.*[0-9]"),
  285. ("cuda_nvrtc", "libnvrtc.so.*[0-9]"),
  286. ("cuda_nvrtc", "libnvrtc-builtins.so.*[0-9]"),
  287. ("cuda_runtime", "libcudart.so.*[0-9]"),
  288. ("cuda_cupti", "libcupti.so.*[0-9]"),
  289. ("cufft", "libcufft.so.*[0-9]"),
  290. ("curand", "libcurand.so.*[0-9]"),
  291. ("nvjitlink", "libnvJitLink.so.*[0-9]"),
  292. ("cusparse", "libcusparse.so.*[0-9]"),
  293. ("cusparselt", "libcusparseLt.so.*[0-9]"),
  294. ("cusolver", "libcusolver.so.*[0-9]"),
  295. ("nccl", "libnccl.so.*[0-9]"),
  296. ("nvshmem", "libnvshmem_host.so.*[0-9]"),
  297. ("cufile", "libcufile.so.*[0-9]"),
  298. ]
  299. # If error is passed, re-raise it if it's not about one of the abovementioned
  300. # libraries
  301. if err is not None and not [
  302. lib for _, lib in cuda_libs if lib.split(".", 1)[0] in err.args[0]
  303. ]:
  304. raise err
  305. # Otherwise, try to preload dependencies from site-packages
  306. for lib_folder, lib_name in cuda_libs:
  307. _preload_cuda_lib(lib_folder, lib_name)
  308. # libnvToolsExt is Optional Dependency
  309. _preload_cuda_lib("nvtx", "libnvToolsExt.so.*[0-9]", required=False)
  # See Note [Global dependencies]
  def _load_global_deps() -> None:
      """Load ``libtorch_global_deps`` from this package's ``lib`` folder with RTLD_GLOBAL.

      Does nothing on Windows. If the initial load fails with ``OSError``,
      the CUDA dependencies are preloaded via ``_preload_cuda_deps(err)``
      (which re-raises errors unrelated to those libraries) and the load is
      retried once. If the initial load succeeds and ``/proc/self/maps``
      shows ``libcudart.so`` mapped, the CUDA dependencies are preloaded on
      a best-effort basis.
      """
      if platform.system() == "Windows":
          return

      # Determine the file extension based on the platform
      lib_ext = ".dylib" if platform.system() == "Darwin" else ".so"
      lib_name = f"libtorch_global_deps{lib_ext}"
      here = os.path.abspath(__file__)
      global_deps_lib_path = os.path.join(os.path.dirname(here), "lib", lib_name)

      try:
          ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
          # Workaround slim-wheel CUDA dependency bugs in cusparse and cudnn by preloading nvjitlink
          # and nvrtc. In CUDA-12.4+ cusparse depends on nvjitlink, but does not have rpath when
          # shipped as wheel, which results in OS picking wrong/older version of nvjitlink library
          # if `LD_LIBRARY_PATH` is defined, see https://github.com/pytorch/pytorch/issues/138460
          # Similar issue exist in cudnn that dynamically loads nvrtc, unaware of its relative path.
          # See https://github.com/pytorch/pytorch/issues/145580
          try:
              with open("/proc/self/maps") as f:
                  _maps = f.read()

              # libtorch_global_deps.so always depends in cudart, check if its installed and loaded
              if "libcudart.so" not in _maps:
                  return
              # If all above-mentioned conditions are met, preload CUDA dependencies
              _preload_cuda_deps()
          except Exception:
              # Deliberately best-effort: the global deps library is already
              # loaded at this point, so any failure in the workaround is
              # ignored rather than failing the import.
              pass

      except OSError as err:
          # Can happen for wheel with cuda libs as PYPI deps
          # As PyTorch is not purelib, but nvidia-*-cu12 is
          _preload_cuda_deps(err)
          ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL)
  if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and (
      platform.system() != "Windows"
  ):
      # Do it the hard way.  You might want to load libtorch with RTLD_GLOBAL in a
      # few circumstances:
      #
      #   1. You're in a build environment (e.g., fbcode) where
      #      libtorch_global_deps is not available, but you still need
      #      to get mkl to link in with RTLD_GLOBAL or it will just
      #      not work.
      #
      #   2. You're trying to run PyTorch under UBSAN and you need
      #      to ensure that only one copy of libtorch is loaded, so
      #      vptr checks work properly
      #
      # If you're using this setting, you must verify that all the libraries
      # you load consistently use the same libstdc++, or you may have
      # mysterious segfaults.
      #
      # The dlopen flags are switched only for the duration of the import
      # and then put back to their previous value.
      old_flags = sys.getdlopenflags()
      sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY)

      from torch._C import *  # noqa: F403

      sys.setdlopenflags(old_flags)
      del old_flags

  else:
      # Easy way.  You want this most of the time, because it will prevent
      # C++ symbols from libtorch clobbering C++ symbols from other
      # libraries, leading to mysterious segfaults.
      #
      # If building in an environment where libtorch_global_deps isn't available
      # like parts of fbsource, but where RTLD_GLOBAL causes segfaults, you will
      # want USE_RTLD_GLOBAL_WITH_LIBTORCH = False and USE_GLOBAL_DEPS = False
      #
      # See Note [Global dependencies]
      if USE_GLOBAL_DEPS:
          _load_global_deps()
      from torch._C import *  # noqa: F403
  379. class SymInt:
  380. """
  381. Like an int (including magic methods), but redirects all operations on the
  382. wrapped node. This is used in particular to symbolically record operations
  383. in the symbolic shape workflow.
  384. """
  385. def __init__(self, node):
  386. # This field MUST be named node; C++ binding code assumes that this
  387. # class has a field named node that stores SymNode
  388. self.node = node
  389. def __bool__(self):
  390. return builtins.bool(self != 0)
  391. def __int__(self):
  392. return self.node.int_()
  393. def __index__(self):
  394. return self.node.int_()
  395. # Magic methods installed by torch.fx.experimental.sym_node
  396. def __round__(self, ndigits=None):
  397. return self
  398. def __truediv__(self, other):
  399. if isinstance(other, (builtins.float, SymFloat)):
  400. return sym_float(self).__float_truediv__(other)
  401. if not isinstance(other, (builtins.int, SymInt)):
  402. return NotImplemented
  403. return self.__int_truediv__(other)
  404. def __rtruediv__(self, other):
  405. if isinstance(other, (builtins.float, SymFloat)):
  406. return sym_float(self).__rfloat_truediv__(other)
  407. if not isinstance(other, (builtins.int, SymInt)):
  408. return NotImplemented
  409. return self.__rint_truediv__(other)
  410. def __floordiv__(self, other):
  411. if isinstance(other, (builtins.float, SymFloat)):
  412. return sym_float(math.floor(sym_float(self) / other))
  413. if not isinstance(other, (builtins.int, SymInt)):
  414. return NotImplemented
  415. return self.__int_floordiv__(other)
  416. def __rfloordiv__(self, other):
  417. if isinstance(other, (builtins.float, SymFloat)):
  418. return sym_float(math.floor(other / sym_float(self)))
  419. if not isinstance(other, (builtins.int, SymInt)):
  420. return NotImplemented
  421. return self.__rint_floordiv__(other)
  422. # nb: complex is impossible to handle correctly lol, with
  423. # negative base and integral float need to diverge semantics and
  424. # just always return complex. Neener neener pretend this problem
  425. # doesn't exist
  426. def __pow__(self, other):
  427. if isinstance(other, (builtins.float, SymFloat)):
  428. return sym_float(self).__pow__(other)
  429. if not isinstance(other, (builtins.int, SymInt)):
  430. return NotImplemented
  431. # Guards! This guard is necessary because we need to know it to
  432. # determine the output type of this operation
  433. if other >= 0:
  434. return self.__pow_by_natural__(other)
  435. else:
  436. # Mercifully, when the exponent is negative, Python just promotes
  437. # to doubles and does a float pow:
  438. #
  439. # if (Py_SIZE(b) < 0 && c == NULL) {
  440. # /* if exponent is negative and there's no modulus:
  441. # return a float. This works because we know
  442. # that this calls float_pow() which converts its
  443. # arguments to double. */
  444. # Py_DECREF(a);
  445. # Py_DECREF(b);
  446. # return PyFloat_Type.tp_as_number->nb_power(v, w, x);
  447. # }
  448. return sym_float(self).__pow__(sym_float(other))
  449. def __rpow__(self, other):
  450. if isinstance(other, (builtins.float, SymFloat)):
  451. return sym_float(self).__rpow__(other)
  452. if not isinstance(other, (builtins.int, SymInt)):
  453. return NotImplemented
  454. if self >= 0: # self is exponent
  455. return self.__rpow_by_natural__(other)
  456. else:
  457. return sym_float(self).__rpow__(sym_float(other))
  458. def __eq__(self, other: object) -> builtins.bool:
  459. raise TypeError("type stub not overridden")
  460. def __lt__(self, other) -> builtins.bool:
  461. raise TypeError("type stub not overridden")
  462. def __gt__(self, other) -> builtins.bool:
  463. raise TypeError("type stub not overridden")
  464. def __le__(self, other) -> builtins.bool:
  465. raise TypeError("type stub not overridden")
  466. def __ge__(self, other) -> builtins.bool:
  467. raise TypeError("type stub not overridden")
  468. def __add__(self, other) -> "SymInt":
  469. raise TypeError("type stub not overridden")
  470. def __radd__(self, other) -> "SymInt":
  471. raise TypeError("type stub not overridden")
  472. def __rmul__(self, other) -> "SymInt":
  473. raise TypeError("type stub not overridden")
  474. def __mod__(self, other: "IntLikeType") -> "SymInt":
  475. raise TypeError("type stub not overridden")
  476. def __mul__(self, other) -> "SymInt":
  477. raise TypeError("type stub not overridden")
  478. def __pow_by_natural__(self, other) -> "SymInt":
  479. raise TypeError("type stub not overridden")
  480. def __rpow_by_natural__(self, other) -> "SymInt":
  481. raise TypeError("type stub not overridden")
  482. def __int_truediv__(self, other) -> "SymFloat":
  483. raise TypeError("type stub not overridden")
  484. def __rint_truediv__(self, other) -> "SymFloat":
  485. raise TypeError("type stub not overridden")
  486. def __int_floordiv__(self, other) -> "SymFloat":
  487. raise TypeError("type stub not overridden")
  488. def __rint_floordiv__(self, other) -> "SymFloat":
  489. raise TypeError("type stub not overridden")
  490. def __sym_max__(self, other):
  491. raise TypeError("type stub not overridden")
  492. def __sym_min__(self, other):
  493. raise TypeError("type stub not overridden")
  494. def __sym_float__(self):
  495. raise TypeError("type stub not overridden")
  496. def __neg__(self):
  497. raise TypeError("type stub not overridden")
  498. def __sub__(self, other: "IntLikeType") -> "SymInt":
  499. raise TypeError("type stub not overridden")
  500. def __rsub__(self, other: "IntLikeType") -> "SymInt":
  501. raise TypeError("type stub not overridden")
  502. def __and__(self, other) -> "SymInt":
  503. raise TypeError("type stub not overridden")
  504. def __or__(self, other) -> "SymInt":
  505. raise TypeError("type stub not overridden")
  506. def __repr__(self):
  507. return self.node._graph_repr()
  508. def _sympy_(self):
  509. return self.node.expr
  510. def __hash__(self) -> builtins.int:
  511. if self.node.is_nested_int():
  512. return hash(self.node.nested_int())
  513. else:
  514. # We could support constant SymInts as well, but not doing it for now
  515. raise TypeError("unhashable type: non-nested SymInt")
  516. # TODO: Force specialization
  517. # This can't be done because the TypeError here is load bearing
  518. # for einops
  519. # https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237
  520. # return hash(builtins.int(self))
  521. def as_integer_ratio(self) -> tuple["SymInt", builtins.int]:
  522. """Represent this int as an exact integer ratio"""
  523. return self, 1
    def bit_length(self) -> builtins.int:
        """Return ``int.bit_length()`` of this value after converting it to a
        builtin ``int``."""
        # TODO: A more relaxed guard is possible here, where you guard to
        # allow all integer quantities which would result in the same bit
        # length. We can also just make a dedicated Sympy function for
        # computing this quantity and represent it symbolically.
        return builtins.int(self).bit_length()
    def conjugate(self) -> "SymInt":
        """Return ``self`` unchanged (the same object, not a copy)."""
        return self
  532. class SymFloat:
  533. """
  534. Like a float (including magic methods), but redirects all operations on the
  535. wrapped node. This is used in particular to symbolically record operations
  536. in the symbolic shape workflow.
  537. """
  538. def __init__(self, node):
  539. # This field MUST be named node; C++ binding code assumes that this
  540. # class has a field named node that stores SymNode
  541. self.node = node
  542. def __truediv__(self, other):
  543. if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
  544. return NotImplemented
  545. return self.__float_truediv__(sym_float(other))
  546. def __rtruediv__(self, other):
  547. if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
  548. return NotImplemented
  549. return self.__rfloat_truediv__(sym_float(other))
  550. def __floordiv__(self, other):
  551. if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
  552. return NotImplemented
  553. return sym_float(math.floor(self / sym_float(other)))
  554. def __rfloordiv__(self, other):
  555. if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
  556. return NotImplemented
  557. return sym_float(math.floor(sym_float(other) / self))
  558. def __bool__(self):
  559. return self.node.bool_()
  560. def __float__(self):
  561. return self.node.guard_float("", 0)
  562. def __int__(self):
  563. return self.__trunc__().__int__()
  564. # Symbolic power does NOT work with negative base, this is to avoid
  565. # potential complex outputs
  566. def __pow__(self, other):
  567. if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
  568. return NotImplemented
  569. torch._check(self >= 0)
  570. return self.__float_pow__(other)
  571. def __rpow__(self, other):
  572. if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
  573. return NotImplemented
  574. torch._check(other >= 0)
  575. return self.__rfloat_pow__(other)
  576. # Magic methods installed by torch.fx.experimental.sym_node
  577. def __eq__(self, other: object) -> builtins.bool:
  578. raise TypeError("type stub not overridden")
  579. def __lt__(self, other) -> builtins.bool:
  580. raise TypeError("type stub not overridden")
  581. def __gt__(self, other) -> builtins.bool:
  582. raise TypeError("type stub not overridden")
  583. def __le__(self, other) -> builtins.bool:
  584. raise TypeError("type stub not overridden")
  585. def __ge__(self, other) -> builtins.bool:
  586. raise TypeError("type stub not overridden")
  587. def __float_pow__(self, other) -> "SymFloat":
  588. raise TypeError("type stub not overridden")
  589. def __rfloat_pow__(self, other) -> "SymFloat":
  590. raise TypeError("type stub not overridden")
  591. def __float_truediv__(self, other) -> "SymFloat":
  592. raise TypeError("type stub not overridden")
  593. def __rfloat_truediv__(self, other) -> "SymFloat":
  594. raise TypeError("type stub not overridden")
  595. def __trunc__(self):
  596. raise TypeError("type stub not overridden")
  597. def __sym_max__(self, other):
  598. raise TypeError("type stub not overridden")
  599. def __sym_min__(self, other):
  600. raise TypeError("type stub not overridden")
  601. def __sym_int__(self):
  602. raise TypeError("type stub not overridden")
  603. def is_integer(self):
  604. """Return True if the float is an integer."""
  605. raise TypeError("type stub not overridden")
  606. def as_integer_ratio(self) -> tuple[builtins.int, builtins.int]:
  607. """Represent this float as an exact integer ratio"""
  608. return builtins.float(self).as_integer_ratio()
  609. def __repr__(self):
  610. return self.node._graph_repr()
  611. def _sympy_(self):
  612. return self.node.expr
  613. def __hash__(self):
  614. return hash(builtins.float(self))
  615. def conjugate(self) -> "SymFloat":
  616. """Returns the complex conjugate of the float."""
  617. return self
  618. def hex(self) -> str:
  619. """Returns the hexadecimal representation of the float."""
  620. return self.node.guard_float("", 0).hex()
  621. class SymBool:
  622. """
  623. Like a bool (including magic methods), but redirects all operations on the
  624. wrapped node. This is used in particular to symbolically record operations
  625. in the symbolic shape workflow.
  626. Unlike regular bools, regular boolean operators will force extra guards instead
  627. of symbolically evaluate. Use the bitwise operators instead to handle this.
  628. """
  629. def __init__(self, node):
  630. # This field MUST be named node; C++ binding code assumes that this
  631. # class has a field named node that stores SymNode
  632. self.node = node
  633. def __bool__(self):
  634. return self.node.bool_()
  635. def __int__(self):
  636. return builtins.int(self.node.bool_())
  637. # Magic methods installed by torch.fx.experimental.sym_node
  638. def __and__(self, other) -> "SymBool":
  639. raise TypeError("type stub not overridden")
  640. def __or__(self, other) -> "SymBool":
  641. raise TypeError("type stub not overridden")
  642. # We very carefully define __sym_not__, and not a number of other
  643. # plausible alternatives:
  644. #
  645. # - We do not override __not__ because this is not a real magic
  646. # method; you cannot override the meaning of the not builtin in
  647. # Python. We use the name 'sym_not' to clarify that in user code you
  648. # cannot use the builtin not or operator.not_ or operator.__not__ and
  649. # hit this magic method; you must use our custom sym_not operator.
  650. #
  651. # - We do not override the __invert__ method because SymBool is
  652. # meant to be usable in situations where bool is expected. However,
  653. # bitwise negation ~a does the wrong thing with booleans (because
  654. # bool is a subclass of int, so ~1 = -2 which is not falseish.)
  655. # This would be a giant footgun, so we get around it by defining
  656. # our own operator. Note that bitwise and/or do the right thing,
  657. # so we reuse the conventional operators there for readability.
  658. #
  659. def __sym_not__(self) -> "SymBool":
  660. raise TypeError("type stub not overridden")
  661. def __sym_ite__(self, then_val, else_val):
  662. raise TypeError("type stub not overridden")
  663. def __eq__(self, other) -> builtins.bool:
  664. raise TypeError("type stub not overridden")
  665. def __repr__(self):
  666. return self.node._graph_repr()
  667. def _sympy_(self):
  668. return self.node.expr
  669. def __hash__(self):
  670. if self.node.is_constant():
  671. return hash(self.node.bool_())
  672. else:
  673. # Force specialization
  674. return hash(builtins.bool(self))
  675. def __sym_float__(self):
  676. """
  677. Provides a SymFloat representation (0.0 or 1.0) for this SymBool.
  678. Called by torch.sym_float() when casting SymBool to float.
  679. """
  680. from torch.fx.experimental.sym_node import wrap_node
  681. return wrap_node(self.node.sym_float())
  682. def sym_not(a):
  683. r"""SymInt-aware utility for logical negation.
  684. Args:
  685. a (SymBool or bool): Object to negate
  686. """
  687. import sympy
  688. if overrides.has_torch_function_unary(a):
  689. return overrides.handle_torch_function(sym_not, (a,), a)
  690. if hasattr(a, "__sym_not__"):
  691. return a.__sym_not__()
  692. if isinstance(a, sympy.Basic):
  693. return ~a # type: ignore[operator]
  694. return not a
  695. def sym_float(a):
  696. r"""SymInt-aware utility for float casting.
  697. Args:
  698. a (SymInt, SymFloat, or object): Object to cast
  699. """
  700. if overrides.has_torch_function_unary(a):
  701. return overrides.handle_torch_function(sym_float, (a,), a)
  702. if isinstance(a, SymFloat):
  703. return a
  704. elif hasattr(a, "__sym_float__"):
  705. return a.__sym_float__()
  706. return builtins.float(a) # type: ignore[operator]
  707. def sym_int(a):
  708. r"""SymInt-aware utility for int casting.
  709. Args:
  710. a (SymInt, SymFloat, or object): Object to cast
  711. """
  712. if overrides.has_torch_function_unary(a):
  713. return overrides.handle_torch_function(sym_int, (a,), a)
  714. if isinstance(a, SymInt):
  715. return a
  716. elif isinstance(a, SymFloat):
  717. return math.trunc(a)
  718. return builtins.int(a) # type: ignore[operator]
  719. def sym_max(a, b):
  720. """
  721. SymInt-aware utility for max which avoids branching on a < b.
  722. Unlike builtins.max(), this only works for int/float, and it always
  723. promotes to float if any argument is float (unlike builtins.max, which
  724. will faithfully preserve the type of the input argument).
  725. """
  726. if overrides.has_torch_function((a, b)):
  727. return overrides.handle_torch_function(sym_max, (a, b), a, b)
  728. if isinstance(a, (SymInt, SymFloat)):
  729. return a.__sym_max__(b)
  730. elif isinstance(b, (SymInt, SymFloat)):
  731. # Due to promotion semantics, this is operator is commutative:
  732. # max(1, 1.0) === max(1.0, 1) === 1.0
  733. return b.__sym_max__(a)
  734. # TODO: Probably can make bool work too, just lazy
  735. all_types, float_types = __all_and_float_types()
  736. assert isinstance(a, all_types), type(a)
  737. assert isinstance(b, all_types), type(b)
  738. if isinstance(a, float_types) or isinstance(b, float_types):
  739. return builtins.float(builtins.max(a, b)) # type: ignore[call-overload]
  740. else:
  741. return builtins.max(a, b) # type: ignore[call-overload]
  742. def __all_and_float_types() -> tuple[tuple[type, ...], tuple[type, ...]]:
  743. try:
  744. import numpy as np
  745. all_types: tuple[type, ...] = (
  746. np.integer,
  747. np.floating,
  748. builtins.int,
  749. builtins.float,
  750. )
  751. float_types: tuple[type, ...] = (np.floating, builtins.float)
  752. except ModuleNotFoundError:
  753. all_types = (builtins.int, builtins.float)
  754. float_types = (builtins.float,)
  755. return all_types, float_types
  756. def sym_min(a, b):
  757. """SymInt-aware utility for min()."""
  758. if overrides.has_torch_function((a, b)):
  759. return overrides.handle_torch_function(sym_min, (a, b), a, b)
  760. if isinstance(a, (SymInt, SymFloat)):
  761. return a.__sym_min__(b)
  762. elif isinstance(b, (SymInt, SymFloat)):
  763. return b.__sym_min__(a)
  764. all_types, float_types = __all_and_float_types()
  765. assert isinstance(a, all_types), type(a)
  766. assert isinstance(b, all_types), type(b)
  767. if isinstance(a, float_types) or isinstance(b, float_types):
  768. return builtins.float(builtins.min(a, b)) # type: ignore[call-overload]
  769. else:
  770. return builtins.min(a, b) # type: ignore[call-overload]
  771. def sym_sum(args):
  772. """
  773. N-ary add which is faster to compute for long lists than iterated binary
  774. addition. Only does something special for integers.
  775. """
  776. if overrides.has_torch_function(args):
  777. return overrides.handle_torch_function(sym_sum, args, args)
  778. found = None
  779. for a in args:
  780. if not isinstance(a, (SymInt, builtins.int)):
  781. return builtins.sum(args)
  782. if isinstance(a, SymInt):
  783. found = a.node
  784. if found is None:
  785. return builtins.sum(args)
  786. from torch.fx.experimental.sym_node import to_node, wrap_node
  787. return wrap_node(found.sym_sum(tuple(to_node(found, a) for a in args)))
  788. # Drop in replacement for math.sqrt, math.sin, math.cos etc
  789. def _get_sym_math_fn(name):
  790. def fn(a):
  791. if overrides.has_torch_function_unary(a):
  792. return overrides.handle_torch_function(fn, (a,), a)
  793. if isinstance(a, SymInt):
  794. a = torch.sym_float(a)
  795. if hasattr(a, f"__sym_{name}__"):
  796. return getattr(a, f"__sym_{name}__")()
  797. return getattr(math, name)(a)
  798. return fn
# Register one module-level function `_sym_<name>` for each math function
# listed below. The loop variables are pre-declared here and deleted again
# afterwards so they do not remain as attributes of this module.
__fn, __name, __sym_name = None, "", ""
for __name in (
    "sqrt",
    "cos",
    "cosh",
    "sin",
    "sinh",
    "tan",
    "tanh",
    "asin",
    "acos",
    "atan",
    "log2",
):
    __sym_name = f"_sym_{__name}"
    __fn = _get_sym_math_fn(__name)
    # Rename the closure so it reports itself as `_sym_<name>` rather than `fn`.
    __fn.__qualname__ = __fn.__name__ = __sym_name
    globals()[__sym_name] = __fn

# The factory is deleted too, so it cannot be called after this point.
del __fn, __name, __sym_name, _get_sym_math_fn

# Adding temporary shortcut
# (publishes `_sym_sqrt` under the public name `sym_sqrt` and exports it)
sym_sqrt = globals()["_sym_sqrt"]
__all__.append("sym_sqrt")
  821. def sym_ite(b, t, f):
  822. """SymInt-aware utility for ternary operator (``t if b else f``.)"""
  823. if overrides.has_torch_function((b, t, f)):
  824. return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f)
  825. assert isinstance(b, (SymBool, builtins.bool)) and type(t) is type(f)
  826. if isinstance(b, SymBool):
  827. return b.__sym_ite__(t, f)
  828. return t if b else f
# Create a fresh unbacked int, from a (possibly unbacked int) expression.
def sym_fresh_size(expr):
    # Wraps `expr` in a tensor and immediately reads the value back with
    # .item(); the value returned is whatever .item() produces.
    return torch.tensor(expr).item()
# Check to see if we can load C extensions, and if not provide some guidance
# on what the problem might be.
try:
    # _initExtension is chosen (arbitrarily) as a sentinel.
    from torch._C import _initExtension
except ImportError:
    import torch._C as _C_for_compiled_check

    # `__file__` being None is taken to mean that `torch._C` resolved to the
    # source-tree folder rather than the compiled extension, so the generic
    # ImportError is replaced with targeted guidance.
    if _C_for_compiled_check.__file__ is None:
        raise ImportError(
            textwrap.dedent(
                """
                Failed to load PyTorch C extensions:
                    It appears that PyTorch has loaded the `torch/_C` folder
                    of the PyTorch repository rather than the C extensions which
                    are expected in the `torch._C` namespace. This can occur when
                    using the `install` workflow. e.g.
                        $ python -m pip install --no-build-isolation -v . && python -c "import torch"

                    This error can generally be solved using the `develop` workflow
                        $ python -m pip install --no-build-isolation -v -e . && python -c "import torch"  # This should succeed
                    or by running Python from a different directory.
                """
            ).strip()
        ) from None  # `from None` hides the original ImportError from the traceback
    raise  # If __file__ is not None the cause is unknown, so just re-raise.
  856. # The torch._C submodule is already loaded via `from torch._C import *` above
  857. # Make an explicit reference to the _C submodule to appease linters
  858. from torch import _C as _C
# Export the public names of `_C` through `__all__` and, for callables and
# classes, rewrite `__module__` so they present themselves as members of this
# module. The loop variables are deleted afterwards so they do not linger as
# module attributes.
__name, __obj = "", None
for __name in dir(_C):
    # Public means: no leading underscore and not one of the `*Base` names.
    if __name[0] != "_" and not __name.endswith("Base"):
        __all__.append(__name)
        __obj = getattr(_C, __name)
        if callable(__obj) or inspect.isclass(__obj):
            # Note: `__name__` (with trailing underscores) is this module's
            # name, while `__name` is the loop variable.
            if __obj.__module__ != __name__:  # "torch"
                # TODO: fix their module from C++ side
                if __name not in {
                    "DisableTorchFunctionSubclass",
                    "DisableTorchFunction",
                    "Generator",
                }:
                    __obj.__module__ = __name__  # "torch"

    elif __name == "TensorBase":
        # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
        delattr(sys.modules[__name__], __name)

del __name, __obj
if not TYPE_CHECKING:
    # issue 38137 and python issue 43367. Submodules of a C extension are
    # non-standard, and attributes of those submodules cannot be pickled since
    # pickle expects to be able to import them as "from _C.sub import attr"
    # which fails with "_C is not a package"
    def _import_extension_to_sys_modules(module, memo=None):
        """Register every submodule reachable from ``module`` in ``sys.modules``.

        Args:
            module: module whose attributes are scanned for submodules.
            memo: set of modules already visited; created on the first call and
                shared across the recursion so each module is scanned once.
        """
        if memo is None:
            memo = set()
        if module in memo:
            return
        memo.add(module)
        module_name = module.__name__
        for name in dir(module):
            member = getattr(module, name)
            member_name = getattr(member, "__name__", "")
            # Only follow modules whose name starts with the parent's name.
            if inspect.ismodule(member) and member_name.startswith(module_name):
                # setdefault: an entry that already exists is left untouched.
                sys.modules.setdefault(member_name, member)
                # Recurse for submodules (e.g., `_C._dynamo.eval_frame`)
                _import_extension_to_sys_modules(member, memo)

    _import_extension_to_sys_modules(_C)
    # One-shot helper: removed so it is not left behind as a module attribute.
    del _import_extension_to_sys_modules
  898. ################################################################################
  899. # Define basic utilities
  900. ################################################################################
  901. def typename(obj: _Any, /) -> str:
  902. """
  903. String representation of the type of an object.
  904. This function returns a fully qualified string representation of an object's type.
  905. Args:
  906. obj (object): The object whose type to represent
  907. Returns:
  908. str: the type of the object `o`
  909. Example:
  910. >>> x = torch.tensor([1, 2, 3])
  911. >>> torch.typename(x)
  912. 'torch.LongTensor'
  913. >>> torch.typename(torch.nn.Parameter)
  914. 'torch.nn.parameter.Parameter'
  915. """
  916. if isinstance(obj, torch.Tensor):
  917. return obj.type()
  918. module = getattr(obj, "__module__", "") or ""
  919. qualname = ""
  920. if hasattr(obj, "__qualname__"):
  921. qualname = obj.__qualname__
  922. elif hasattr(obj, "__name__"):
  923. qualname = obj.__name__
  924. else:
  925. module = obj.__class__.__module__ or ""
  926. qualname = obj.__class__.__qualname__
  927. if module in {"", "builtins"}:
  928. return qualname
  929. return f"{module}.{qualname}"
def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]:
    r"""Returns True if `obj` is a PyTorch tensor.

    This is exactly ``isinstance(obj, torch.Tensor)``, so instances of
    ``torch.Tensor`` subclasses are accepted as well.

    Args:
        obj (object): Object to test

    Example::

        >>> x = torch.tensor([1, 2, 3])
        >>> torch.is_tensor(x)
        True

    """
    return isinstance(obj, torch.Tensor)
def is_storage(obj: _Any, /) -> builtins.bool:
    r"""Returns True if `obj` is a PyTorch storage object.

    The test is an exact type match, ``type(obj) in _storage_classes``, so an
    instance of a subclass that is not itself listed in ``_storage_classes``
    is reported as ``False``.

    Args:
        obj (Object): Object to test

    Example::

        >>> import torch
        >>> # UntypedStorage (recommended)
        >>> tensor = torch.tensor([1, 2, 3])
        >>> storage = tensor.untyped_storage()
        >>> torch.is_storage(storage)
        True
        >>>
        >>> # TypedStorage (legacy)
        >>> typed_storage = torch.TypedStorage(5, dtype=torch.float32)
        >>> torch.is_storage(typed_storage)
        True
        >>>
        >>> # regular tensor (should return False)
        >>> torch.is_storage(tensor)
        False
        >>>
        >>> # non-storage object
        >>> torch.is_storage([1, 2, 3])
        False
    """
    return type(obj) in _storage_classes
# Thread-local holder for the default-device state: set_default_device()
# stores the active context in its `device_context` attribute and
# get_default_device() reads it back.
_GLOBAL_DEVICE_CONTEXT = threading.local()
  967. def get_default_device() -> "torch.device":
  968. r"""Gets the default ``torch.Tensor`` to be allocated on ``device``"""
  969. global _GLOBAL_DEVICE_CONTEXT
  970. from torch.overrides import _get_current_function_mode_stack
  971. from torch.utils._device import DeviceContext
  972. def _get_device_with_index(device):
  973. if device.index is not None:
  974. return device
  975. else:
  976. # TODO: Call like get_device_index() method corresponding to
  977. # each device type
  978. return torch.tensor([]).device
  979. # Get device from any active DeviceContext.
  980. device_mode = next(
  981. filter(
  982. lambda mode: isinstance(mode, DeviceContext),
  983. reversed(_get_current_function_mode_stack()),
  984. ),
  985. None,
  986. )
  987. if device_mode:
  988. device = device_mode.device
  989. return _get_device_with_index(device)
  990. device_context = getattr(_GLOBAL_DEVICE_CONTEXT, "device_context", None)
  991. if device_context is not None:
  992. return _get_device_with_index(device_context.device)
  993. return torch.device("cpu")
  994. def set_default_device(device: "Device") -> None:
  995. """Sets the default ``torch.Tensor`` to be allocated on ``device``. This
  996. does not affect factory function calls which are called with an explicit
  997. ``device`` argument. Factory calls will be performed as if they
  998. were passed ``device`` as an argument.
  999. To only temporarily change the default device instead of setting it
  1000. globally, use ``with torch.device(device):`` instead.
  1001. The default device is initially ``cpu``. If you set the default tensor
  1002. device to another device (e.g., ``cuda``) without a device index, tensors
  1003. will be allocated on whatever the current device for the device type,
  1004. even after :func:`torch.cuda.set_device` is called.
  1005. .. warning::
  1006. This function imposes a slight performance cost on every Python
  1007. call to the torch API (not just factory functions). If this
  1008. is causing problems for you, please comment on
  1009. https://github.com/pytorch/pytorch/issues/92701
  1010. .. note::
  1011. This doesn't affect functions that create tensors that share the same memory as the input, like:
  1012. :func:`torch.from_numpy` and :func:`torch.frombuffer`
  1013. Args:
  1014. device (device or string): the device to set as default
  1015. Example::
  1016. >>> # xdoctest: +SKIP("requires cuda, changes global state")
  1017. >>> torch.get_default_device()
  1018. device(type='cpu')
  1019. >>> torch.set_default_device('cuda') # current device is 0
  1020. >>> torch.get_default_device()
  1021. device(type='cuda', index=0)
  1022. >>> torch.set_default_device('cuda')
  1023. >>> torch.cuda.set_device('cuda:1') # current device is 1
  1024. >>> torch.get_default_device()
  1025. device(type='cuda', index=1)
  1026. >>> torch.set_default_device('cuda:1')
  1027. >>> torch.get_default_device()
  1028. device(type='cuda', index=1)
  1029. """
  1030. global _GLOBAL_DEVICE_CONTEXT
  1031. if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
  1032. device_context = _GLOBAL_DEVICE_CONTEXT.device_context
  1033. if device_context is not None:
  1034. device_context.__exit__(None, None, None)
  1035. if device is None:
  1036. device_context = None
  1037. else:
  1038. from torch.utils._device import DeviceContext
  1039. device_context = DeviceContext(device)
  1040. device_context.__enter__()
  1041. _GLOBAL_DEVICE_CONTEXT.device_context = device_context
  1042. def set_default_tensor_type(t: type["torch.Tensor"] | str, /) -> None:
  1043. r"""
  1044. .. warning::
  1045. This function is deprecated as of PyTorch 2.1, please use :func:`torch.set_default_dtype()` and
  1046. :func:`torch.set_default_device()` as alternatives.
  1047. Sets the default ``torch.Tensor`` type to floating point tensor type
  1048. ``t``. This type will also be used as default floating point type for
  1049. type inference in :func:`torch.tensor`.
  1050. The default floating point tensor type is initially ``torch.FloatTensor``.
  1051. Args:
  1052. t (type or string): the floating point tensor type or its name
  1053. Example::
  1054. >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
  1055. >>> torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32
  1056. torch.float32
  1057. >>> torch.set_default_tensor_type(torch.DoubleTensor)
  1058. >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
  1059. torch.float64
  1060. """
  1061. if isinstance(t, str):
  1062. t = _import_dotted_name(t)
  1063. _C._set_default_tensor_type(t)
def set_default_dtype(d: "torch.dtype", /) -> None:
    r"""

    Sets the default floating point dtype to :attr:`d`. Supports floating point dtype
    as inputs. Other dtypes will cause torch to raise an exception.

    When PyTorch is initialized its default floating point dtype is torch.float32,
    and the intent of set_default_dtype(torch.float64) is to facilitate NumPy-like
    type inference. The default floating point dtype is used to:

    1. Implicitly determine the default complex dtype. When the default floating type is float16,
       the default complex dtype is complex32. For float32, the default complex dtype is complex64.
       For float64, it is complex128. For bfloat16, an exception will be raised because
       there is no corresponding complex type for bfloat16.
    2. Infer the dtype for tensors constructed using Python floats or complex Python
       numbers. See examples below.
    3. Determine the result of type promotion between bool and integer tensors and
       Python floats and complex Python numbers.

    Args:
        d (:class:`torch.dtype`): the floating point dtype to make the default.

    Example:
        >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?")
        >>> # initial default for floating point is torch.float32
        >>> # Python floats are interpreted as float32
        >>> torch.tensor([1.2, 3]).dtype
        torch.float32
        >>> # initial default for complex is torch.complex64
        >>> # Complex Python numbers are interpreted as complex64
        >>> torch.tensor([1.2, 3j]).dtype
        torch.complex64

        >>> torch.set_default_dtype(torch.float64)
        >>> # Python floats are now interpreted as float64
        >>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
        torch.float64
        >>> # Complex Python numbers are now interpreted as complex128
        >>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
        torch.complex128

        >>> torch.set_default_dtype(torch.float16)
        >>> # Python floats are now interpreted as float16
        >>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
        torch.float16
        >>> # Complex Python numbers are now interpreted as complex32
        >>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
        torch.complex32

    """
    # All validation and bookkeeping happens in the C extension.
    _C._set_default_dtype(d)
  1107. def use_deterministic_algorithms(
  1108. mode: builtins.bool,
  1109. *,
  1110. warn_only: builtins.bool = False,
  1111. ) -> None:
  1112. r"""Sets whether PyTorch operations must use "deterministic"
  1113. algorithms. That is, algorithms which, given the same input, and when
  1114. run on the same software and hardware, always produce the same output.
  1115. When enabled, operations will use deterministic algorithms when available,
  1116. and if only nondeterministic algorithms are available they will throw a
  1117. :class:`RuntimeError` when called.
  1118. .. note:: This setting alone is not always enough to make an application
  1119. reproducible. Refer to :ref:`reproducibility` for more information.
  1120. .. note:: :func:`torch.set_deterministic_debug_mode` offers an alternative
  1121. interface for this feature.
  1122. The following normally-nondeterministic operations will act
  1123. deterministically when ``mode=True``:
  1124. * :class:`torch.nn.Conv1d` when called on CUDA tensor
  1125. * :class:`torch.nn.Conv2d` when called on CUDA tensor
  1126. * :class:`torch.nn.Conv3d` when called on CUDA tensor
  1127. * :class:`torch.nn.ConvTranspose1d` when called on CUDA tensor
  1128. * :class:`torch.nn.ConvTranspose2d` when called on CUDA tensor
  1129. * :class:`torch.nn.ConvTranspose3d` when called on CUDA tensor
  1130. * :class:`torch.nn.ReplicationPad1d` when attempting to differentiate a CUDA tensor
  1131. * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor
  1132. * :class:`torch.nn.ReplicationPad3d` when attempting to differentiate a CUDA tensor
  1133. * :func:`torch.bmm` when called on sparse-dense CUDA tensors
  1134. * :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor
  1135. and the index is a list of tensors
  1136. * :func:`torch.Tensor.index_put` with ``accumulate=False``
  1137. * :func:`torch.Tensor.index_put` with ``accumulate=True`` when called on a CPU
  1138. tensor
  1139. * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU
  1140. tensor
  1141. * :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor
  1142. * :func:`torch.gather` when called on a CUDA tensor that requires grad
  1143. * :func:`torch.index_add` when called on CUDA tensor
  1144. * :func:`torch.index_select` when attempting to differentiate a CUDA tensor
  1145. * :func:`torch.repeat_interleave` when attempting to differentiate a CUDA tensor
  1146. * :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor
  1147. * :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor
  1148. * :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'`` and called on CUDA tensor
  1149. The following normally-nondeterministic operations will throw a
  1150. :class:`RuntimeError` when ``mode=True``:
  1151. * :class:`torch.nn.AvgPool3d` when attempting to differentiate a CUDA tensor
  1152. * :class:`torch.nn.AdaptiveAvgPool2d` when attempting to differentiate a CUDA tensor
  1153. * :class:`torch.nn.AdaptiveAvgPool3d` when attempting to differentiate a CUDA tensor
  1154. * :class:`torch.nn.MaxPool3d` when attempting to differentiate a CUDA tensor
  1155. * :class:`torch.nn.AdaptiveMaxPool2d` when attempting to differentiate a CUDA tensor
  1156. * :class:`torch.nn.FractionalMaxPool2d` when attempting to differentiate a CUDA tensor
  1157. * :class:`torch.nn.FractionalMaxPool3d` when attempting to differentiate a CUDA tensor
  1158. * :class:`torch.nn.MaxUnpool1d`
  1159. * :class:`torch.nn.MaxUnpool2d`
  1160. * :class:`torch.nn.MaxUnpool3d`
  1161. * :func:`torch.nn.functional.interpolate` when attempting to differentiate a CUDA tensor
  1162. and one of the following modes is used:
  1163. - ``linear``
  1164. - ``bilinear``
  1165. - ``bicubic``
  1166. - ``trilinear``
  1167. * :class:`torch.nn.ReflectionPad1d` when attempting to differentiate a CUDA tensor
  1168. * :class:`torch.nn.ReflectionPad2d` when attempting to differentiate a CUDA tensor
  1169. * :class:`torch.nn.ReflectionPad3d` when attempting to differentiate a CUDA tensor
  1170. * :class:`torch.nn.NLLLoss` when called on a CUDA tensor
  1171. * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor
  1172. * :class:`torch.nn.EmbeddingBag` when attempting to differentiate a CUDA tensor when
  1173. ``mode='max'``
  1174. * :func:`torch.Tensor.put_` when ``accumulate=False``
  1175. * :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor
  1176. * :func:`torch.histc` when called on a CUDA tensor
  1177. * :func:`torch.bincount` when called on a CUDA tensor and ``weights``
  1178. tensor is given
  1179. * :func:`torch.median` with indices output when called on a CUDA tensor
  1180. * :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor
  1181. * :func:`torch.cumsum` when called on a CUDA tensor when dtype is floating point or complex
  1182. * :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor
  1183. * :func:`torch.Tensor.resize_` when called with a quantized tensor
  1184. In addition, several operations fill uninitialized memory when this setting
  1185. is turned on and when
  1186. :attr:`torch.utils.deterministic.fill_uninitialized_memory` is turned on.
  1187. See the documentation for that attribute for more information.
  1188. Note that deterministic operations tend to have worse performance than
  1189. nondeterministic operations.
  1190. When this setting is turned on, the Inductor deterministic mode is also tuned on
  1191. automatically. In deterministic mode, Inductor would avoid doing on device benchmarking
  1192. that affect numerics. This includes:
  1193. - don't pad matmul input shapes. Without enabling deterministic mode, Inductor would do
  1194. benchmarking to check if padding matmul shape is beneficial.
  1195. - don't autotune templates. Inductor has templates for kernels like matmul/conv/attention.
  1196. Without enabling deterministic mode, Inductor would do autotuning to
  1197. pick the best configs for those templates and adopt it if it's faster
  1198. than the kernel in eager mode. In deterministic mode, we pick the eager kernel.
  1199. - don't autotune triton configs for reduction. Reduction numerics are
  1200. very sensitive to triton configs. In deterministic mode, Inductor
  1201. will use some heuristics to pick the most promising configs rather
  1202. than do autotuning.
  1203. - Skip autotuning for reduction in coordinate descent tuning.
    - Don't benchmark for the computation/communication reordering pass.
    - Disable the feature that dynamically scales down the RBLOCK triton config for higher
      occupancy.
  1207. .. note::
  1208. This flag does not detect or prevent nondeterministic behavior caused
  1209. by calling an inplace operation on a tensor with an internal memory
  1210. overlap or by giving such a tensor as the :attr:`out` argument for an
  1211. operation. In these cases, multiple writes of different data may target
  1212. a single memory location, and the order of writes is not guaranteed.
  1213. Args:
  1214. mode (:class:`bool`): If True, makes potentially nondeterministic
  1215. operations switch to a deterministic algorithm or throw a runtime
  1216. error. If False, allows nondeterministic operations.
  1217. Keyword args:
  1218. warn_only (:class:`bool`, optional): If True, operations that do not
  1219. have a deterministic implementation will throw a warning instead of
  1220. an error. Default: ``False``
  1221. Example::
  1222. >>> # xdoctest: +SKIP
  1223. >>> torch.use_deterministic_algorithms(True)
  1224. # Backward mode nondeterministic error
  1225. >>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward()
  1226. ...
  1227. RuntimeError: avg_pool3d_backward_cuda does not have a deterministic implementation...
  1228. """
  1229. import torch._inductor.config as inductor_config
  1230. inductor_config.deterministic = mode
  1231. _C._set_deterministic_algorithms(mode, warn_only=warn_only)
  1232. def are_deterministic_algorithms_enabled() -> builtins.bool:
  1233. r"""Returns True if the global deterministic flag is turned on. Refer to
  1234. :func:`torch.use_deterministic_algorithms` documentation for more details.
  1235. """
  1236. return _C._get_deterministic_algorithms()
  1237. def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool:
  1238. r"""Returns True if the global deterministic flag is set to warn only.
  1239. Refer to :func:`torch.use_deterministic_algorithms` documentation for more
  1240. details.
  1241. """
  1242. return _C._get_deterministic_algorithms_warn_only()
  1243. def set_deterministic_debug_mode(debug_mode: builtins.int | str) -> None:
  1244. r"""Sets the debug mode for deterministic operations.
  1245. .. note:: This is an alternative interface for
  1246. :func:`torch.use_deterministic_algorithms`. Refer to that function's
  1247. documentation for details about affected operations.
  1248. Args:
  1249. debug_mode(str or int): If "default" or 0, don't error or warn on
  1250. nondeterministic operations. If "warn" or 1, warn on
  1251. nondeterministic operations. If "error" or 2, error on
  1252. nondeterministic operations.
  1253. """
  1254. # NOTE: builtins.int is used here because int in this scope resolves
  1255. # to torch.int
  1256. if not isinstance(debug_mode, (builtins.int, str)):
  1257. raise TypeError(f"debug_mode must be str or int, but got {type(debug_mode)}")
  1258. if isinstance(debug_mode, str):
  1259. if debug_mode == "default":
  1260. debug_mode = 0
  1261. elif debug_mode == "warn":
  1262. debug_mode = 1
  1263. elif debug_mode == "error":
  1264. debug_mode = 2
  1265. else:
  1266. raise RuntimeError(
  1267. "invalid value of debug_mode, expected one of `default`, "
  1268. f"`warn`, `error`, but got {debug_mode}"
  1269. )
  1270. if debug_mode == 0:
  1271. _C._set_deterministic_algorithms(False)
  1272. elif debug_mode == 1:
  1273. _C._set_deterministic_algorithms(True, warn_only=True)
  1274. elif debug_mode == 2:
  1275. _C._set_deterministic_algorithms(True)
  1276. else:
  1277. raise RuntimeError(
  1278. f"invalid value of debug_mode, expected 0, 1, or 2, but got {debug_mode}"
  1279. )
  1280. def get_deterministic_debug_mode() -> builtins.int:
  1281. r"""Returns the current value of the debug mode for deterministic
  1282. operations. Refer to :func:`torch.set_deterministic_debug_mode`
  1283. documentation for more details.
  1284. """
  1285. if _C._get_deterministic_algorithms():
  1286. if _C._get_deterministic_algorithms_warn_only():
  1287. return 1
  1288. else:
  1289. return 2
  1290. else:
  1291. return 0
  1292. def get_float32_matmul_precision() -> str:
  1293. r"""Returns the current value of float32 matrix multiplication precision. Refer to
  1294. :func:`torch.set_float32_matmul_precision` documentation for more details.
  1295. """
  1296. return _C._get_float32_matmul_precision()
  1297. def set_float32_matmul_precision(precision: str) -> None:
  1298. r"""Sets the internal precision of float32 matrix multiplications.
  1299. Running float32 matrix multiplications in lower precision may significantly increase
  1300. performance, and in some programs the loss of precision has a negligible impact.
  1301. Supports three settings:
  1302. * "highest", float32 matrix multiplications use the float32 datatype (24 mantissa
  1303. bits with 23 bits explicitly stored) for internal computations.
  1304. * "high", float32 matrix multiplications either use the TensorFloat32 datatype (10
  1305. mantissa bits explicitly stored) or treat each float32 number as the sum of two bfloat16 numbers
  1306. (approximately 16 mantissa bits with 14 bits explicitly stored), if the appropriate fast matrix multiplication
  1307. algorithms are available. Otherwise float32 matrix multiplications are computed
  1308. as if the precision is "highest". See below for more information on the bfloat16
  1309. approach.
  1310. * "medium", float32 matrix multiplications use the bfloat16 datatype (8 mantissa
  1311. bits with 7 bits explicitly stored) for internal computations, if a fast matrix multiplication algorithm
  1312. using that datatype internally is available. Otherwise float32
  1313. matrix multiplications are computed as if the precision is "high".
  1314. When using "high" precision, float32 multiplications may use a bfloat16-based algorithm
  1315. that is more complicated than simply truncating to some smaller number mantissa bits
  1316. (e.g. 10 for TensorFloat32, 7 for bfloat16 explicitly stored). Refer to [Henry2019]_ for a complete
  1317. description of this algorithm. To briefly explain here, the first step is to realize
  1318. that we can perfectly encode a single float32 number as the sum of three bfloat16
  1319. numbers (because float32 has 23 mantissa bits while bfloat16 has 7 explicitly stored, and both have the
  1320. same number of exponent bits). This means that the product of two float32 numbers can
  1321. be exactly given by the sum of nine products of bfloat16 numbers. We can then trade
  1322. accuracy for speed by dropping some of these products. The "high" precision algorithm
  1323. specifically keeps only the three most significant products, which conveniently excludes
  1324. all of the products involving the last 8 mantissa bits of either input. This means that
  1325. we can represent our inputs as the sum of two bfloat16 numbers rather than three.
  1326. Because bfloat16 fused-multiply-add (FMA) instructions are typically >10x faster than
  1327. float32 ones, it's faster to do three multiplications and 2 additions with bfloat16
  1328. precision than it is to do a single multiplication with float32 precision.
  1329. .. [Henry2019] http://arxiv.org/abs/1904.06376
  1330. .. note::
  1331. This does not change the output dtype of float32 matrix multiplications,
  1332. it controls how the internal computation of the matrix multiplication is performed.
  1333. .. note::
  1334. This does not change the precision of convolution operations. Other flags,
  1335. like `torch.backends.cudnn.allow_tf32`, may control the precision of convolution
  1336. operations.
  1337. .. note::
  1338. This flag currently only affects one native device type: CUDA.
  1339. If "high" or "medium" are set then the TensorFloat32 datatype will be used
  1340. when computing float32 matrix multiplications, equivalent to setting
  1341. `torch.backends.cuda.matmul.allow_tf32 = True`. When "highest" (the default)
  1342. is set then the float32 datatype is used for internal computations, equivalent
  1343. to setting `torch.backends.cuda.matmul.allow_tf32 = False`.
  1344. Args:
  1345. precision(str): can be set to "highest" (default), "high", or "medium" (see above).
  1346. """
  1347. _C._set_float32_matmul_precision(precision)
  1348. def set_warn_always(b: builtins.bool, /) -> None:
  1349. r"""When this flag is False (default) then some PyTorch warnings may only
  1350. appear once per process. This helps avoid excessive warning information.
  1351. Setting it to True causes these warnings to always appear, which may be
  1352. helpful when debugging.
  1353. Args:
  1354. b (:class:`bool`): If True, force warnings to always be emitted
  1355. If False, set to the default behaviour
  1356. """
  1357. _C._set_warnAlways(b)
  1358. def is_warn_always_enabled() -> builtins.bool:
  1359. r"""Returns True if the global warn_always flag is turned on. Refer to
  1360. :func:`torch.set_warn_always` documentation for more details.
  1361. """
  1362. return _C._get_warnAlways()
  1363. ################################################################################
  1364. # Define error checking functions
  1365. ################################################################################
  1366. # These error checking functions must be kept consistent with their C++
  1367. # equivalents. Their C++ equivalents are mentioned where applicable.
  1368. def _check_with(
  1369. error_type,
  1370. cond: builtins.bool | SymBool,
  1371. message: _Callable[[], str],
  1372. ): # noqa: F811
  1373. if not isinstance(cond, (builtins.bool, SymBool)):
  1374. raise TypeError(f"cond must be a bool, but got {type(cond)}")
  1375. from torch.fx.experimental.symbolic_shapes import expect_true
  1376. if expect_true(cond):
  1377. return
  1378. # error_type must be a subclass of Exception and not subclass of Warning
  1379. assert issubclass(error_type, Exception) and not issubclass(error_type, Warning)
  1380. if message is None:
  1381. message_evaluated = (
  1382. "Expected cond to be True, but got False. (Could this error "
  1383. "message be improved? If so, please report an enhancement request "
  1384. "to PyTorch.)"
  1385. )
  1386. else:
  1387. if not callable(message):
  1388. raise TypeError("message must be a callable")
  1389. message_evaluated = str(message())
  1390. raise error_type(message_evaluated)
  1391. def _check(cond, message=None): # noqa: F811
  1392. r"""Throws error containing an optional message if the specified condition
  1393. is False.
  1394. Error type: ``RuntimeError``
  1395. C++ equivalent: ``TORCH_CHECK``
  1396. Args:
  1397. cond (:class:`bool`): If False, throw error
  1398. message (Callable, optional): Callable that returns either a string or
  1399. an object that has a ``__str__()`` method to be used as the error
  1400. message. Default: ``None``
  1401. """
  1402. _check_with(RuntimeError, cond, message) # pyrefly: ignore [bad-argument-type]
  1403. # TODO add deprecation annotation
  1404. def _check_is_size(i, message=None, *, max=None):
  1405. """Checks that a given integer is a valid size (i.e., is non-negative).
  1406. You should use this over ``_check(i >= 0)`` because it can prevent
  1407. ``GuardOnDataDependentSymNode`` exceptions by opting yourself into alternate
  1408. semantics for ``guard_size_oblivious`` tests that treat values 0 and 1
  1409. equivalently to all other values.
  1410. When max is not None, this specifies an upper bound equivalent to
  1411. ``_check(i <= max)``. This bound is also subject to alternate semantics:
  1412. in ``guard_size_oblivious`` tests, we assume that a constant max bound is
  1413. treated equivalently to all other values. Symbolic max bounds are not yet
  1414. supported.
  1415. NB: Do NOT use this in contexts where a -1 size would be valid (indicating
  1416. to infer the size from context, or if you should wrap-around or truncate).
  1417. Only use this if the only valid value is an honest to goodness size.
  1418. """
  1419. # This is responsible for the expect_true
  1420. _check(i >= 0, message)
  1421. from torch.fx.experimental.symbolic_shapes import _advise_is_size
  1422. _advise_is_size(i)
  1423. if max is not None:
  1424. _check(i <= max, message)
  1425. from torch.fx.experimental.symbolic_shapes import _advise_is_bounded
  1426. _advise_is_bounded(i, max)
  1427. def _check_index(cond, message=None): # noqa: F811
  1428. r"""Throws error containing an optional message if the specified condition
  1429. is False.
  1430. Error type: ``IndexError``
  1431. C++ equivalent: ``TORCH_CHECK_INDEX``
  1432. Args:
  1433. cond (:class:`bool`): If False, throw error
  1434. message (Callable, optional): Callable that returns either a string or
  1435. an object that has a ``__str__()`` method to be used as the error
  1436. message. Default: ``None``
  1437. """
  1438. _check_with(IndexError, cond, message) # pyrefly: ignore [bad-argument-type]
  1439. def _check_value(cond, message=None): # noqa: F811
  1440. r"""Throws error containing an optional message if the specified condition
  1441. is False.
  1442. Error type: ``ValueError``
  1443. C++ equivalent: ``TORCH_CHECK_VALUE``
  1444. Args:
  1445. cond (:class:`bool`): If False, throw error
  1446. message (Callable, optional): Callable that returns either a string or
  1447. an object that has a ``__str__()`` method to be used as the error
  1448. message. Default: ``None``
  1449. """
  1450. _check_with(ValueError, cond, message) # pyrefly: ignore [bad-argument-type]
  1451. def _check_type(cond, message=None): # noqa: F811
  1452. r"""Throws error containing an optional message if the specified condition
  1453. is False.
  1454. Error type: ``TypeError``
  1455. C++ equivalent: ``TORCH_CHECK_TYPE``
  1456. Args:
  1457. cond (:class:`bool`): If False, throw error
  1458. message (Callable, optional): Callable that returns either a string or
  1459. an object that has a ``__str__()`` method to be used as the error
  1460. message. Default: ``None``
  1461. """
  1462. _check_with(TypeError, cond, message) # pyrefly: ignore [bad-argument-type]
  1463. def _check_not_implemented(cond, message=None): # noqa: F811
  1464. r"""Throws error containing an optional message if the specified condition
  1465. is False.
  1466. Error type: ``NotImplementedError``
  1467. C++ equivalent: ``TORCH_CHECK_NOT_IMPLEMENTED``
  1468. Args:
  1469. cond (:class:`bool`): If False, throw error
  1470. message (Callable, optional): Callable that returns either a string or
  1471. an object that has a ``__str__()`` method to be used as the error
  1472. message. Default: ``None``
  1473. """
  1474. _check_with(
  1475. NotImplementedError,
  1476. cond,
  1477. # pyrefly: ignore [bad-argument-type]
  1478. message,
  1479. )
  1480. def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811
  1481. if not is_tensor(cond):
  1482. raise TypeError(f"cond must be a tensor, but got {type(cond)}")
  1483. if not cond.dtype == torch.bool:
  1484. raise TypeError(f"cond tensor must have dtype torch.bool, but got {cond.dtype}")
  1485. _check_with(error_type, cond._is_all_true().item(), message) # type: ignore[arg-type]
  1486. # C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
  1487. def _check_tensor_all(cond, message=None): # noqa: F811
  1488. r"""Throws error containing an optional message if the specified condition
  1489. is False.
  1490. Error type: ``RuntimeError``
  1491. C++ equivalent: ``TORCH_CHECK_TENSOR_ALL``
  1492. Args:
  1493. cond (:class:`torch.Tensor`): Tensor of dtype ``torch.bool``. If any
  1494. element is ``False``, throw error
  1495. message (Callable, optional): Callable that returns either a string or
  1496. an object that has a ``__str__()`` method to be used as the error
  1497. message. Default: ``None``
  1498. """
  1499. _check_tensor_all_with(RuntimeError, cond, message)
  1500. ################################################################################
  1501. # Define numeric constants
  1502. ################################################################################
  1503. # For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and
  1504. # NumPy consistency (https://numpy.org/devdocs/reference/constants.html)
  1505. from math import e, inf, nan, pi
# ``newaxis`` is simply an alias for ``None``.
newaxis: None = None

# Publish the math constants imported above, plus ``newaxis``, as public names.
__all__.extend(["e", "pi", "nan", "inf", "newaxis"])
  1508. ################################################################################
  1509. # Define Storage and Tensor classes
  1510. ################################################################################
  1511. from torch._tensor import Tensor # usort: skip
  1512. # needs to be after torch.Tensor is defined to avoid circular dependencies
  1513. from torch import storage as storage # usort: skip
  1514. from torch.storage import (
  1515. _LegacyStorage,
  1516. _StorageBase,
  1517. _warn_typed_storage_removal,
  1518. TypedStorage,
  1519. UntypedStorage,
  1520. )
  1521. # NOTE: New <type>Storage classes should never be added. When adding a new
  1522. # dtype, use torch.storage.TypedStorage directly.
  1523. class ByteStorage(_LegacyStorage):
  1524. @classproperty
  1525. def dtype(self):
  1526. _warn_typed_storage_removal(stacklevel=3)
  1527. return self._dtype
  1528. @classproperty
  1529. def _dtype(self):
  1530. return torch.uint8
  1531. class DoubleStorage(_LegacyStorage):
  1532. @classproperty
  1533. def dtype(self):
  1534. _warn_typed_storage_removal(stacklevel=3)
  1535. return self._dtype
  1536. @classproperty
  1537. def _dtype(self):
  1538. return torch.double
  1539. class FloatStorage(_LegacyStorage):
  1540. @classproperty
  1541. def dtype(self):
  1542. _warn_typed_storage_removal(stacklevel=3)
  1543. return self._dtype
  1544. @classproperty
  1545. def _dtype(self):
  1546. return torch.float
  1547. class HalfStorage(_LegacyStorage):
  1548. @classproperty
  1549. def dtype(self):
  1550. _warn_typed_storage_removal(stacklevel=3)
  1551. return self._dtype
  1552. @classproperty
  1553. def _dtype(self):
  1554. return torch.half
  1555. class LongStorage(_LegacyStorage):
  1556. @classproperty
  1557. def dtype(self):
  1558. _warn_typed_storage_removal(stacklevel=3)
  1559. return self._dtype
  1560. @classproperty
  1561. def _dtype(self):
  1562. return torch.long
  1563. class IntStorage(_LegacyStorage):
  1564. @classproperty
  1565. def dtype(self):
  1566. _warn_typed_storage_removal(stacklevel=3)
  1567. return self._dtype
  1568. @classproperty
  1569. def _dtype(self):
  1570. return torch.int
  1571. class ShortStorage(_LegacyStorage):
  1572. @classproperty
  1573. def dtype(self):
  1574. _warn_typed_storage_removal(stacklevel=3)
  1575. return self._dtype
  1576. @classproperty
  1577. def _dtype(self):
  1578. return torch.short
  1579. class CharStorage(_LegacyStorage):
  1580. @classproperty
  1581. def dtype(self):
  1582. _warn_typed_storage_removal(stacklevel=3)
  1583. return self._dtype
  1584. @classproperty
  1585. def _dtype(self):
  1586. return torch.int8
  1587. class BoolStorage(_LegacyStorage):
  1588. @classproperty
  1589. def dtype(self):
  1590. _warn_typed_storage_removal(stacklevel=3)
  1591. return self._dtype
  1592. @classproperty
  1593. def _dtype(self):
  1594. return torch.bool
  1595. class BFloat16Storage(_LegacyStorage):
  1596. @classproperty
  1597. def dtype(self):
  1598. _warn_typed_storage_removal(stacklevel=3)
  1599. return self._dtype
  1600. @classproperty
  1601. def _dtype(self):
  1602. return torch.bfloat16
  1603. class ComplexDoubleStorage(_LegacyStorage):
  1604. @classproperty
  1605. def dtype(self):
  1606. _warn_typed_storage_removal(stacklevel=3)
  1607. return self._dtype
  1608. @classproperty
  1609. def _dtype(self):
  1610. return torch.cdouble
  1611. class ComplexFloatStorage(_LegacyStorage):
  1612. @classproperty
  1613. def dtype(self):
  1614. _warn_typed_storage_removal(stacklevel=3)
  1615. return self._dtype
  1616. @classproperty
  1617. def _dtype(self):
  1618. return torch.cfloat
  1619. class QUInt8Storage(_LegacyStorage):
  1620. @classproperty
  1621. def dtype(self):
  1622. _warn_typed_storage_removal(stacklevel=3)
  1623. return self._dtype
  1624. @classproperty
  1625. def _dtype(self):
  1626. return torch.quint8
  1627. class QInt8Storage(_LegacyStorage):
  1628. @classproperty
  1629. def dtype(self):
  1630. _warn_typed_storage_removal(stacklevel=3)
  1631. return self._dtype
  1632. @classproperty
  1633. def _dtype(self):
  1634. return torch.qint8
  1635. class QInt32Storage(_LegacyStorage):
  1636. @classproperty
  1637. def dtype(self):
  1638. _warn_typed_storage_removal(stacklevel=3)
  1639. return self._dtype
  1640. @classproperty
  1641. def _dtype(self):
  1642. return torch.qint32
  1643. class QUInt4x2Storage(_LegacyStorage):
  1644. @classproperty
  1645. def dtype(self):
  1646. _warn_typed_storage_removal(stacklevel=3)
  1647. return self._dtype
  1648. @classproperty
  1649. def _dtype(self):
  1650. return torch.quint4x2
  1651. class QUInt2x4Storage(_LegacyStorage):
  1652. @classproperty
  1653. def dtype(self):
  1654. _warn_typed_storage_removal(stacklevel=3)
  1655. return self._dtype
  1656. @classproperty
  1657. def _dtype(self):
  1658. return torch.quint2x4
# Every storage class known to this module: UntypedStorage, TypedStorage, and
# the legacy per-dtype storage classes defined above. Later in this file the
# set is converted to a list and passed to _C._init_names.
_storage_classes: set[type[TypedStorage | UntypedStorage]] = {
    UntypedStorage,
    DoubleStorage,
    FloatStorage,
    LongStorage,
    IntStorage,
    ShortStorage,
    CharStorage,
    ByteStorage,
    HalfStorage,
    BoolStorage,
    QUInt8Storage,
    QInt8Storage,
    QInt32Storage,
    BFloat16Storage,
    ComplexFloatStorage,
    ComplexDoubleStorage,
    QUInt4x2Storage,
    QUInt2x4Storage,
    TypedStorage,
}

# The _tensor_classes set is initialized by the call to initialize_python_bindings.
_tensor_classes: set[type["torch.Tensor"]] = set()
  1682. # If you edit these imports, please update torch/__init__.py.in as well
  1683. from torch import amp as amp, random as random, serialization as serialization
  1684. from torch._tensor_str import set_printoptions
  1685. from torch.amp import autocast, GradScaler
  1686. from torch.random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state
  1687. from torch.serialization import load, save
  1688. ################################################################################
  1689. # Initialize extension
  1690. ################################################################################
  1691. # Shared memory manager needs to know the exact location of manager executable
  1692. def _manager_path():
  1693. if platform.system() == "Windows":
  1694. return b""
  1695. path = get_file_path("torch", "bin", "torch_shm_manager")
  1696. prepare_multiprocessing_environment(get_file_path("torch"))
  1697. if not os.path.exists(path):
  1698. raise RuntimeError("Unable to find torch_shm_manager at " + path)
  1699. return path.encode("utf-8")
# Initialize the C extension, passing it the shared-memory manager path, then
# remove the one-shot helper from the module namespace.
_C._initExtension(_manager_path())
del _manager_path
# Appease the type checker: it can't deal with direct setting of globals().
# Note that we will see "too many" functions when reexporting this way; there
# is not a good way to fix this problem.  Perhaps, try to redesign VariableFunctions
# so that this import is good enough
if TYPE_CHECKING:
    # Some type signatures pulled in from _VariableFunctions here clash with
    # signatures already imported. For now these clashes are ignored; see
    # PR #43339 for details.
    from torch._C._VariableFunctions import *  # type: ignore[assignment, misc] # noqa: F403

    # Fixup segment_reduce visibility
    # (mirrors the runtime loop below, which publishes it under the
    # underscore-prefixed name)
    _segment_reduce = segment_reduce
    del segment_reduce  # noqa: F821

# Ops not to be exposed in `torch` namespace,
# mostly helper ops.
PRIVATE_OPS = ("unique_dim",)

# Runtime re-export: copy every attribute of _C._VariableFunctions into this
# module's globals, skipping dunder names and PRIVATE_OPS. Each object's
# __module__ is rewritten to this module's name, and names without a leading
# underscore are appended to __all__.
__name, __obj = "", None
for __name in dir(_C._VariableFunctions):
    if __name.startswith("__") or __name in PRIVATE_OPS:
        continue
    __obj = getattr(_C._VariableFunctions, __name)
    __obj.__module__ = __name__  # "torch"
    # Hide some APIs that should not be public
    if __name == "segment_reduce":
        # TODO: Once the undocumented FC window is passed, remove the line below
        globals()[__name] = __obj
        # From here on the op is handled under its underscore-prefixed name,
        # so the __all__ check below skips it.
        __name = "_" + __name
    globals()[__name] = __obj
    if not __name.startswith("_"):
        __all__.append(__name)

# Drop the loop temporaries so they do not linger as module attributes.
del __name, __obj
  1732. ################################################################################
  1733. # Add torch.dtype instances to the public API
  1734. ################################################################################
  1735. import torch
# Export every module attribute that is a torch.dtype instance as a public name.
__all__.extend(
    name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype)
)
  1739. ################################################################################
  1740. # Import TorchDynamo's lazy APIs to avoid circular dependencies
  1741. ################################################################################
  1742. # needs to be before from torch.functional import * to avoid circular dependencies
  1743. from torch._compile import _disable_dynamo # usort: skip
  1744. ################################################################################
  1745. # Import interface functions defined in Python
  1746. ################################################################################
  1747. # needs to be after the above ATen bindings so we can overwrite from Python side
  1748. from torch import _VF as _VF, functional as functional # usort: skip
  1749. from torch.functional import * # usort: skip # noqa: F403
  1750. ################################################################################
  1751. # Remove unnecessary members
  1752. ################################################################################
# Both names were imported from torch.storage earlier in this file; unbind
# them from this module's namespace.
del _StorageBase
del _LegacyStorage
  1755. ################################################################################
  1756. # Define _assert
  1757. ################################################################################
  1758. # needs to be before the submodule imports to avoid circular dependencies
  1759. def _assert(condition, message):
  1760. r"""A wrapper around Python's assert which is symbolically traceable."""
  1761. if type(condition) is not torch.Tensor and overrides.has_torch_function(
  1762. (condition,)
  1763. ):
  1764. return overrides.handle_torch_function(
  1765. _assert, (condition,), condition, message
  1766. )
  1767. assert condition, message
  1768. ################################################################################
  1769. # Import most common subpackages
  1770. ################################################################################
  1771. # Use the redundant form so that type checkers know that these are a part of
  1772. # the public API. The "regular" import lines are there solely for the runtime
  1773. # side effect of adding to the imported module's members for other users.
  1774. # needs to be before import torch.nn as nn to avoid circular dependencies
  1775. from torch.autograd import ( # usort: skip
  1776. enable_grad as enable_grad,
  1777. inference_mode as inference_mode,
  1778. no_grad as no_grad,
  1779. set_grad_enabled as set_grad_enabled,
  1780. )
  1781. from torch import (
  1782. __config__ as __config__,
  1783. __future__ as __future__,
  1784. _awaits as _awaits,
  1785. accelerator as accelerator,
  1786. autograd as autograd,
  1787. backends as backends,
  1788. cpu as cpu,
  1789. cuda as cuda,
  1790. distributed as distributed,
  1791. distributions as distributions,
  1792. fft as fft,
  1793. futures as futures,
  1794. hub as hub,
  1795. jit as jit,
  1796. linalg as linalg,
  1797. mps as mps,
  1798. mtia as mtia,
  1799. multiprocessing as multiprocessing,
  1800. nested as nested,
  1801. nn as nn,
  1802. optim as optim,
  1803. overrides as overrides,
  1804. profiler as profiler,
  1805. sparse as sparse,
  1806. special as special,
  1807. testing as testing,
  1808. types as types,
  1809. utils as utils,
  1810. version as version,
  1811. xpu as xpu,
  1812. )
  1813. from torch.signal import windows as windows
  1814. # Quantized, sparse, AO, etc. should be last to get imported, as nothing
  1815. # is expected to depend on them.
  1816. from torch import ao as ao # usort: skip
  1817. # nn.quant* depends on ao -- so should be after those.
  1818. import torch.nn.intrinsic
  1819. import torch.nn.qat
  1820. import torch.nn.quantizable
  1821. import torch.nn.quantized
  1822. _C._init_names(list(_storage_classes))
  1823. # attach docstrings to torch and tensor functions
  1824. from torch import _size_docs, _storage_docs, _tensor_docs, _torch_docs
  1825. del _torch_docs, _tensor_docs, _storage_docs, _size_docs
  1826. def compiled_with_cxx11_abi() -> builtins.bool:
  1827. r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1"""
  1828. return True
  1829. from torch import _library as _library, _ops as _ops
  1830. # Import the ops and classes "namespace"
  1831. from torch._ops import ops as ops # usort: skip
  1832. from torch._classes import classes as classes # usort: skip
  1833. sys.modules.setdefault(f"{__name__}.ops", ops)
  1834. sys.modules.setdefault(f"{__name__}.classes", classes)
  1835. # quantization depends on torch.fx and torch.ops
  1836. # Import quantization
  1837. from torch import quantization as quantization # usort: skip
  1838. # Import the quasi random sampler
  1839. from torch import quasirandom as quasirandom # usort: skip
  1840. # If you are seeing this, it means that this call site was not checked if
  1841. # the memory format could be preserved, and it was switched to old default
  1842. # behaviour of contiguous
  1843. legacy_contiguous_format = contiguous_format # defined by _C._initExtension()
  1844. # Register fork handler to initialize OpenMP in child processes (see gh-28389)
  1845. from torch.multiprocessing._atfork import register_after_fork
  1846. register_after_fork(torch.get_num_threads)
  1847. del register_after_fork
  1848. # Import tools that require fully imported torch (for applying
  1849. # torch.jit.script as a decorator, for instance):
  1850. from torch._lobpcg import lobpcg as lobpcg
  1851. # These were previously defined in native_functions.yaml and appeared on the
  1852. # `torch` namespace, but we moved them to c10 dispatch to facilitate custom
  1853. # class usage. We add these lines here to preserve backward compatibility.
  1854. quantized_lstm = ops.aten.quantized_lstm
  1855. quantized_gru = ops.aten.quantized_gru
  1856. # Import experimental masked operations support. See
  1857. # [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more
  1858. # information.
  1859. from torch import masked as masked
  1860. # Import removed ops with error message about removal
  1861. from torch._linalg_utils import ( # type: ignore[misc]
  1862. _symeig as symeig,
  1863. eig,
  1864. lstsq,
  1865. matrix_rank,
  1866. solve,
  1867. )
  1868. from torch.utils.dlpack import from_dlpack, to_dlpack
  1869. class _TorchCompileInductorWrapper:
  1870. compiler_name = "inductor"
  1871. def __init__(self, mode, options, dynamic):
  1872. from torch._inductor.compiler_bisector import CompilerBisector
  1873. self.config: dict[str, _Any] = {}
  1874. self.dynamic = dynamic
  1875. self.apply_mode(mode)
  1876. self.apply_options(options)
  1877. self.apply_options(CompilerBisector.get_config_change("inductor"))
  1878. cuda_version = None
  1879. if hasattr(torch, "version"):
  1880. from torch.torch_version import TorchVersion
  1881. cuda_version = TorchVersion(getattr(torch.version, "cuda", "0.0"))
  1882. if self.config.get("triton.cudagraphs", False) and (
  1883. (cuda_version and cuda_version < "12.6")
  1884. or not profiler_allow_cudagraph_cupti_lazy_reinit_cuda12()
  1885. ):
  1886. os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
  1887. # FIXME: CUDA Graph does not work well with CUPTI teardown.
  1888. # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
  1889. # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
  1890. # Workaround: turn off CUPTI teardown when using CUDA Graphs.
  1891. os.environ["TEARDOWN_CUPTI"] = "0"
  1892. def __eq__(self, other):
  1893. return (
  1894. isinstance(other, _TorchCompileInductorWrapper)
  1895. and self.config == other.config
  1896. and self.dynamic == other.dynamic
  1897. )
  1898. def apply_mode(self, mode: str | None):
  1899. if mode and mode != "default":
  1900. from torch._inductor import list_mode_options
  1901. self.apply_options(list_mode_options(mode, self.dynamic))
  1902. def apply_options(self, options: dict[str, _Any] | None):
  1903. if not options:
  1904. return
  1905. from torch._inductor import config
  1906. current_config: dict[str, _Any] = config.get_config_copy()
  1907. for key, val in options.items():
  1908. attr_name = key.replace("-", "_")
  1909. if attr_name not in current_config:
  1910. raise RuntimeError(
  1911. f"Unexpected optimization option {key}, known options are {list(current_config.keys())}"
  1912. )
  1913. attr_type = config.get_type(attr_name) # type: ignore[attr-defined]
  1914. # Subscriptable generic types don't support isinstance so skip the type
  1915. # check. There doesn't seem to be a good way of checking membership without
  1916. # 3rd party libraries.
  1917. if _get_origin(attr_type) is None:
  1918. if not isinstance(val, attr_type):
  1919. val_type_str = type(val).__name__
  1920. expected_type_str = type(current_config[attr_name]).__name__
  1921. raise RuntimeError(
  1922. f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}"
  1923. )
  1924. self.config[attr_name] = val
  1925. def __call__(self, model_, inputs_):
  1926. from torch._inductor.compile_fx import compile_fx
  1927. return compile_fx(model_, inputs_, config_patches=self.config)
  1928. def get_compiler_config(self):
  1929. from torch._inductor.compile_fx import get_patched_config_dict
  1930. return get_patched_config_dict(config_patches=self.config)
  1931. def reset(self):
  1932. from torch._inductor import config
  1933. if "triton.cudagraphs" in self.config or config.triton.cudagraphs:
  1934. if self.config.get("triton.cudagraphs", True):
  1935. from torch._inductor.cudagraph_trees import reset_cudagraph_trees
  1936. reset_cudagraph_trees()
  1937. class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
  1938. compiler_name = "aotinductor"
  1939. def __init__(self, mode, options, dynamic):
  1940. super().__init__(mode, options, dynamic)
  1941. self.apply_options({"cpp_wrapper": True})
  1942. self.apply_options({"aot_inductor.package": True})
  1943. def __call__(self, model_, inputs_):
  1944. from contextlib import nullcontext
  1945. from unittest import mock
  1946. from torch._guards import detect_fake_mode
  1947. from torch._inductor.virtualized import V
  1948. fake_mode = detect_fake_mode(inputs_)
  1949. ctx = (
  1950. mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
  1951. if fake_mode
  1952. else nullcontext()
  1953. )
  1954. with (
  1955. V.set_aot_compilation(True),
  1956. ctx,
  1957. torch._inductor.config.patch("enable_autograd_for_aot", True),
  1958. ):
  1959. return super().__call__(model_, inputs_)
  1960. class _TorchCompileWrapper:
  1961. def __init__(self, backend, mode, options, dynamic):
  1962. from torch._dynamo.backends.registry import lookup_backend
  1963. if isinstance(backend, str):
  1964. self.compiler_name = backend
  1965. elif hasattr(backend, "__name__"):
  1966. self.compiler_name = backend.__name__
  1967. else:
  1968. self.compiler_name = str(backend)
  1969. self.dynamic = dynamic
  1970. self.compiler_fn = lookup_backend(backend)
  1971. self.kwargs = {}
  1972. # only pass the args if they non-empty
  1973. if mode and mode != "default":
  1974. self.kwargs["mode"] = mode
  1975. if options:
  1976. self.kwargs["options"] = options
  1977. def __eq__(self, other):
  1978. return (
  1979. isinstance(other, _TorchCompileWrapper)
  1980. and self.compiler_fn == other.compiler_fn
  1981. and self.kwargs == other.kwargs
  1982. and self.dynamic == other.dynamic
  1983. )
  1984. def __call__(self, model_, inputs_):
  1985. return self.compiler_fn(model_, inputs_, **self.kwargs)
  1986. def reset(self):
  1987. if hasattr(self.compiler_fn, "reset"):
  1988. self.compiler_fn.reset()
  1989. _InputT = _ParamSpec("_InputT")
  1990. _RetT = _TypeVar("_RetT")
  1991. @_overload
  1992. def compile(
  1993. model: _Callable[_InputT, _RetT],
  1994. *,
  1995. fullgraph: builtins.bool = False,
  1996. dynamic: builtins.bool | None = None,
  1997. backend: str | _Callable = "inductor",
  1998. mode: str | None = None,
  1999. options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None,
  2000. disable: builtins.bool = False,
  2001. ) -> _Callable[_InputT, _RetT]: ...
  2002. @_overload
  2003. def compile(
  2004. model: None = None,
  2005. *,
  2006. fullgraph: builtins.bool = False,
  2007. dynamic: builtins.bool | None = None,
  2008. backend: str | _Callable = "inductor",
  2009. mode: str | None = None,
  2010. options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None,
  2011. disable: builtins.bool = False,
  2012. ) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...
  2013. def compile(
  2014. model: _Callable[_InputT, _RetT] | None = None,
  2015. *,
  2016. fullgraph: builtins.bool = False,
  2017. dynamic: builtins.bool | None = None,
  2018. backend: str | _Callable = "inductor",
  2019. mode: str | None = None,
  2020. options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None,
  2021. disable: builtins.bool = False,
  2022. ) -> (
  2023. _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]
  2024. | _Callable[_InputT, _RetT]
  2025. ):
  2026. """
  2027. Optimizes given model/function using TorchDynamo and specified backend.
  2028. If you are compiling an :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile`
  2029. to compile the module inplace without changing its structure.
  2030. Concretely, for every frame executed within the compiled region, we will attempt
  2031. to compile it and cache the compiled result on the code object for future
  2032. use. A single frame may be compiled multiple times if previous compiled
  2033. results are not applicable for subsequent calls (this is called a "guard
  2034. failure"), you can use TORCH_LOGS=guards to debug these situations.
  2035. Multiple compiled results can be associated with a frame up to
  2036. ``torch._dynamo.config.recompile_limit``, which defaults to 8; at which
  2037. point we will fall back to eager. Note that compile caches are per
  2038. *code object*, not frame; if you dynamically create multiple copies of a
  2039. function, they will all share the same code cache.
  2040. Args:
  2041. model (Callable or None): Module/function to optimize
  2042. fullgraph (bool): If False (default), torch.compile attempts to discover compilable regions
  2043. in the function that it will optimize. If True, then we require that the entire function be
  2044. capturable into a single graph. If this is not possible (that is, if there are graph breaks),
  2045. then this will raise an error. This also opts into unbacked semantics, notably it will turn on
  2046. capture_scalar_outputs and capture_dynamic_output_shape_ops on by default.
  2047. dynamic (bool or None): Use dynamic shape tracing. When this is True, we will up-front attempt
  2048. to generate a kernel that is as dynamic as possible to avoid recompilations when
  2049. sizes change. This may not always work as some operations/optimizations will
  2050. force specialization; use TORCH_LOGS=dynamic to debug overspecialization.
  2051. When this is False, we will NEVER generate dynamic kernels, we will always specialize.
  2052. By default (None), we automatically detect if dynamism has occurred and compile a more
  2053. dynamic kernel upon recompile.
  2054. backend (str or Callable): backend to be used
  2055. - "inductor" is the default backend, which is a good balance between performance and overhead
  2056. - Non experimental in-tree backends can be seen with `torch._dynamo.list_backends()`
  2057. - Experimental or debug in-tree backends can be seen with `torch._dynamo.list_backends(None)`
  2058. - To register an out-of-tree custom backend:
  2059. https://pytorch.org/docs/main/torch.compiler_custom_backends.html#registering-custom-backends
  2060. mode (str): Can be either "default", "reduce-overhead", "max-autotune" or "max-autotune-no-cudagraphs"
  2061. - "default" is the default mode, which is a good balance between performance and overhead
  2062. - "reduce-overhead" is a mode that reduces the overhead of python with CUDA graphs,
  2063. useful for small batches. Reduction of overhead can come at the cost of more memory
  2064. usage, as we will cache the workspace memory required for the invocation so that we
  2065. do not have to reallocate it on subsequent runs. Reduction of overhead is not guaranteed
  2066. to work; today, we only reduce overhead for CUDA only graphs which do not mutate inputs.
  2067. There are other circumstances where CUDA graphs are not applicable; use TORCH_LOG=perf_hints
  2068. to debug.
  2069. - "max-autotune" is a mode that leverages Triton or template based matrix multiplications
  2070. on supported devices and Triton based convolutions on GPU.
  2071. It enables CUDA graphs by default on GPU.
  2072. - "max-autotune-no-cudagraphs" is a mode similar to "max-autotune" but without CUDA graphs
  2073. - To see the exact configs that each mode sets you can call `torch._inductor.list_mode_options()`
  2074. options (dict): A dictionary of options to pass to the backend. Some notable ones to try out are
  2075. - `epilogue_fusion` which fuses pointwise ops into templates. Requires `max_autotune` to also be set
  2076. - `max_autotune` which will profile to pick the best matmul configuration
  2077. - `fallback_random` which is useful when debugging accuracy issues
  2078. - `shape_padding` which pads matrix shapes to better align loads on GPUs especially for tensor cores
  2079. - `triton.cudagraphs` which will reduce the overhead of python with CUDA graphs
  2080. - `trace.enabled` which is the most useful debugging flag to turn on
  2081. - `trace.graph_diagram` which will show you a picture of your graph after fusion
  2082. - `guard_filter_fn` that controls which dynamo guards are saved with compilations.
  2083. This is an unsafe feature and there is no backward compatibility guarantee provided
  2084. for dynamo guards as data types.
  2085. For stable helper functions to use, see the documentations in `torch.compiler`, for example:
  2086. - `torch.compiler.skip_guard_on_inbuilt_nn_modules_unsafe`
  2087. - `torch.compiler.skip_guard_on_all_nn_modules_unsafe`
  2088. - `torch.compiler.keep_tensor_guards_unsafe`
  2089. - For inductor you can see the full list of configs that it supports by calling `torch._inductor.list_options()`
  2090. disable (bool): Turn torch.compile() into a no-op for testing
  2091. Example::
  2092. @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
  2093. def foo(x):
  2094. return torch.sin(x) + torch.cos(x)
  2095. """
  2096. import sysconfig
  2097. _C._log_api_usage_once("torch.compile")
  2098. if sys.version_info >= (3, 15):
  2099. raise RuntimeError("torch.compile is not supported on Python 3.15+")
  2100. elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1 and sys.version_info < (
  2101. 3,
  2102. 13,
  2103. 3,
  2104. ):
  2105. raise RuntimeError(
  2106. "torch.compile is not supported on Python < 3.13.3 built with GIL disabled. "
  2107. "Please use Python 3.13.3+."
  2108. )
  2109. # Decorator mode
  2110. if model is None:
  2111. def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]:
  2112. if model is None:
  2113. raise RuntimeError("Model can't be None")
  2114. return compile( # pyrefly: ignore # no-matching-overload
  2115. model,
  2116. fullgraph=fullgraph,
  2117. dynamic=dynamic,
  2118. backend=backend,
  2119. mode=mode,
  2120. options=options,
  2121. disable=disable,
  2122. )
  2123. return fn
  2124. if mode is not None and options is not None:
  2125. raise RuntimeError(
  2126. "Either mode or options can be specified, but both can't be specified at the same time."
  2127. )
  2128. if mode is None and options is None:
  2129. mode = "default"
  2130. from torch._inductor.compiler_bisector import CompilerBisector
  2131. if bisect_backend := CompilerBisector.get_backend():
  2132. import torch._inductor.config as inductor_config
  2133. # don't override the backend for use cases like vllm
  2134. # which leverages their custom backend.
  2135. if not (
  2136. inductor_config.test_configs.bisect_keep_custom_backend_for_inductor
  2137. and bisect_backend == "inductor"
  2138. and not isinstance(backend, str)
  2139. ):
  2140. backend = bisect_backend
  2141. guard_filter_fn = None
  2142. use_aoti = False
  2143. if options and isinstance(options, dict):
  2144. guard_filter_fn = options.pop("guard_filter_fn", None)
  2145. use_aoti = options.pop("use_aoti", False)
  2146. if torch.compiler.is_exporting():
  2147. warnings.warn(
  2148. "You are calling torch.compile inside torch.export region. "
  2149. "To capture an useful graph, we will implicitly switch to torch.compile(backend=eager)",
  2150. stacklevel=2,
  2151. )
  2152. from torch._higher_order_ops.utils import setup_compilation_env
  2153. # Create wrapper that always uses eager backend during export
  2154. def export_wrapped_fn(*args, **kwargs):
  2155. with setup_compilation_env() as backend: # type: ignore[attr-defined]
  2156. # Force eager backend regardless of original backend
  2157. backend_wrapper = _TorchCompileWrapper(backend, mode, options, dynamic)
  2158. return torch._dynamo.optimize(
  2159. backend=backend_wrapper,
  2160. nopython=fullgraph,
  2161. dynamic=dynamic,
  2162. disable=disable,
  2163. guard_filter_fn=guard_filter_fn,
  2164. # pyrefly: ignore [bad-argument-type]
  2165. )(model)(*args, **kwargs)
  2166. return export_wrapped_fn
  2167. if backend == "inductor":
  2168. if use_aoti:
  2169. backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
  2170. else:
  2171. backend = _TorchCompileInductorWrapper(mode, options, dynamic)
  2172. else:
  2173. backend = _TorchCompileWrapper(backend, mode, options, dynamic)
  2174. return torch._dynamo.optimize(
  2175. backend=backend,
  2176. nopython=fullgraph,
  2177. dynamic=dynamic,
  2178. disable=disable,
  2179. guard_filter_fn=guard_filter_fn,
  2180. )(model) # type: ignore[return-value]
  2181. def _register_device_module(device_type, module):
  2182. r"""Register an external runtime module of the specific :attr:`device_type`
  2183. supported by torch.
  2184. After the :attr:`module` is registered correctly, the user can refer
  2185. the external runtime module as part of torch with attribute torch.xxx.
  2186. """
  2187. # Make sure the device_type represent a supported device type for torch.
  2188. device_type = torch.device(device_type).type
  2189. m = sys.modules[__name__]
  2190. if hasattr(m, device_type):
  2191. raise RuntimeError(
  2192. f"The runtime module of '{device_type}' has already "
  2193. f"been registered with '{getattr(m, device_type)}'"
  2194. )
  2195. setattr(m, device_type, module)
  2196. torch_module_name = ".".join([__name__, device_type])
  2197. sys.modules[torch_module_name] = module
  2198. from torch import (
  2199. export as export,
  2200. func as func,
  2201. library as library,
  2202. return_types as return_types,
  2203. )
  2204. from torch._higher_order_ops import cond as cond, while_loop as while_loop
  2205. from torch.func import vmap as vmap
  2206. if not TYPE_CHECKING:
  2207. from torch import _meta_registrations
  2208. # Enable CUDA Sanitizer
  2209. if "TORCH_CUDA_SANITIZER" in os.environ:
  2210. import torch.cuda._sanitizer as csan
  2211. csan.enable_cuda_sanitizer()
  2212. # Populate magic methods on SymInt and SymFloat
  2213. import torch.fx.experimental.sym_node
  2214. from torch import fx as fx
  2215. # Register MPS specific decomps
  2216. torch.backends.mps._init()
  2217. from torch import compiler as compiler
  2218. class _TritonLibrary:
  2219. lib = torch.library.Library("triton", "DEF")
  2220. ops_table: dict[tuple[str, str], _Callable] = {}
  2221. @classmethod
  2222. def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
  2223. if (op_key, dispatch_key) not in cls.ops_table:
  2224. cls.lib.define(full_schema)
  2225. cls.lib.impl("triton::" + op_key, op_impl, dispatch_key)
  2226. cls.ops_table[(op_key, dispatch_key)] = op_impl
  2227. return cls.ops_table[(op_key, dispatch_key)]
  2228. # Deprecated attributes
  2229. _deprecated_attrs = {
  2230. "has_mps": torch.backends.mps.is_built,
  2231. "has_cuda": torch.backends.cuda.is_built,
  2232. "has_cudnn": torch.backends.cudnn.is_available,
  2233. "has_mkldnn": torch.backends.mkldnn.is_available,
  2234. }
  2235. if TYPE_CHECKING:
  2236. # Import the following modules during type checking to enable code intelligence features,
  2237. # such as auto-completion in tools like pylance, even when these modules are not explicitly
  2238. # imported in user code.
  2239. from torch import (
  2240. _dynamo as _dynamo,
  2241. _inductor as _inductor,
  2242. _subclasses as _subclasses,
  2243. onnx as onnx,
  2244. )
  2245. else:
  2246. _lazy_modules = {
  2247. "_dynamo",
  2248. "_inductor",
  2249. "_export",
  2250. # ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit
  2251. "onnx",
  2252. }
  2253. def __getattr__(name):
  2254. # Deprecated attrs
  2255. replacement = _deprecated_attrs.get(name)
  2256. if replacement is not None:
  2257. import warnings
  2258. warnings.warn(
  2259. f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'",
  2260. stacklevel=2,
  2261. )
  2262. return replacement()
  2263. # Lazy modules
  2264. if name in _lazy_modules:
  2265. return importlib.import_module(f".{name}", __name__)
  2266. raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
  2267. @functools.cache
  2268. def get_device_module(device: torch.device | str | None = None):
  2269. """
  2270. Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...).
  2271. If no device is given, return the module for the current accelerator or CPU if none is present.
  2272. """
  2273. if isinstance(device, torch.device):
  2274. device_module_name = device.type
  2275. elif isinstance(device, str):
  2276. device_module_name = torch.device(device).type
  2277. elif device is None:
  2278. # Using default accelerator type. If no accelerator is available, it automatically returns CPU device.
  2279. device_module_name = torch._C._get_accelerator().type
  2280. else:
  2281. raise RuntimeError(
  2282. f"Invalid value of device '{device}', expect torch.device, str, or None"
  2283. )
  2284. device_module = getattr(torch, device_module_name, None)
  2285. if device_module is None:
  2286. raise RuntimeError(
  2287. f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'."
  2288. )
  2289. return device_module
  2290. def _constrain_as_size(
  2291. symbol,
  2292. min: builtins.int | None = None,
  2293. max: builtins.int | None = None,
  2294. ):
  2295. """
  2296. This indicates that a given int is size-like, and can be used in any context where a size is expected.
  2297. You will typically use this when reading out integers from Tensors, e.g., max.item() or lengths.tolist()
  2298. which then need to be used as tensor constructors. Providing these assertions to PyTorch can help resolve
  2299. GuardOnDataDependentSymNode errors upon export, since we cannot guard on unbacked SymInts.
  2300. This function has unusual semantics in some circumstances in framework
  2301. code, we will treat this int as >= 2 (when we do a size-oblivious guard).
  2302. This makes it easier to use the unbacked int in size contexts,
  2303. as we will often attempt to guard on a size being zero/one
  2304. (e.g., when computing the contiguity of a tensor, or testing if
  2305. broadcasting can occur), which will not work on unbacked SymInts.
  2306. However, if we conservatively assume that the size is not zero/one, we will
  2307. end up with a graph that will still work even if the size is zero/one.
  2308. For more details, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit
  2309. ```
  2310. """
  2311. torch.sym_constrain_range_for_size(symbol, min=min, max=max)
  2312. from torch import _logging
  2313. _logging._init_logs()
  2314. def _import_device_backends():
  2315. """
  2316. Leverage the Python plugin mechanism to load out-of-the-tree device extensions.
  2317. See this RFC: https://github.com/pytorch/pytorch/issues/122468
  2318. """
  2319. from importlib.metadata import entry_points
  2320. group_name = "torch.backends"
  2321. backend_extensions = entry_points(group=group_name)
  2322. for backend_extension in backend_extensions:
  2323. try:
  2324. # Load the extension
  2325. entrypoint = backend_extension.load()
  2326. # Call the entrypoint
  2327. entrypoint()
  2328. except Exception as err:
  2329. raise RuntimeError(
  2330. f"Failed to load the backend extension: {backend_extension.name}. "
  2331. f"You can disable extension auto-loading with TORCH_DEVICE_BACKEND_AUTOLOAD=0."
  2332. ) from err
  2333. def _is_device_backend_autoload_enabled() -> builtins.bool:
  2334. """
  2335. Whether autoloading out-of-the-tree device extensions is enabled.
  2336. The switch depends on the value of the environment variable
  2337. `TORCH_DEVICE_BACKEND_AUTOLOAD`.
  2338. Returns:
  2339. bool: Whether to enable autoloading the extensions. Enabled by default.
  2340. Examples:
  2341. >>> torch._is_device_backend_autoload_enabled()
  2342. True
  2343. """
  2344. # enabled by default
  2345. return os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1") == "1"
  2346. def _as_tensor_fullprec(t):
  2347. """
  2348. Like torch.as_tensor, but when given Python data types it will keep
  2349. them in full precision. Used for calling convention for Dynamo.
  2350. """
  2351. ty = type(t)
  2352. if ty is builtins.float:
  2353. return torch.as_tensor(t, dtype=torch.float64)
  2354. elif ty is builtins.int:
  2355. return torch.as_tensor(t, dtype=torch.int64)
  2356. else:
  2357. return torch.as_tensor(t)
  2358. # `_import_device_backends` should be kept at the end to ensure
  2359. # all the other functions in this module that may be accessed by
  2360. # an autoloaded backend are defined
  2361. if _is_device_backend_autoload_enabled():
  2362. _import_device_backends()