tunable.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822
  1. r"""
  2. This module exposes a TunableOp interface.
  3. Some operations, such as GEMMs, could be implemented using more than one library
  4. or more than one technique. For example, a GEMM could be implemented for CUDA or
  5. ROCm using either the blas or blasLt libraries. Further, ROCm's rocblas and
  6. hipblaslt libraries allow the user to query for all possible algorithms and then
  7. choose one. How does one know which implementation is the fastest and should be
  8. chosen? That's what TunableOp provides.
  9. Enabling TunableOp and Tuning Separately
  10. ========================================
  11. The TunableOp feature is enabled separately from enabling the tuning phase
  12. itself. Enabling TunableOp means that PyTorch will replace any standard
  13. operators with their Tunable implementations. Any call to a TunableOp first
  14. checks whether it has already been tuned for the given operator inputs. If so,
  15. it will immediately call the tuned operation; no further tuning will take place
  16. even when the tuning setting is enabled. Instead if no tuning result is found,
  17. and tuning is enabled, the TunableOp will benchmark every registered
  18. implementation of that operator for the given set of inputs and select the
  19. fastest.
  20. File Input and Output
  21. =====================
  22. The first time any TunableOp is invoked, the internal database of tuned
  23. operations will be prepared by attempting to read the results from the given
  24. file. The default filename is 'tunableop_results.csv'. To support tuning when
  25. multiple GPUs are used across multiple processes, the GPU device ordinal is
  26. automatically inserted into the filename to avoid multiple processes overwriting
  27. the same file.
  28. If tuning is enabled and new tunings are discovered during the course of your
  29. workload, it will also write out to this same filename with all tunings, both
  30. the ones it read in at startup as well as the new ones found at runtime. This
  31. can be used, for example, to build up a tunings file across many workloads by
  32. reusing the same file. The output file is automatically created when the
  33. application terminates. This behavior can be controlled by the C++ and Python
  34. APIs but not the environment variables.
Assuming you specified a filename, you'll end up with a CSV file with contents
like so::

  Validator,PT_VERSION,2.2.0
  Validator,ROCM_VERSION,6.0.0.0-12969-1544e39
  Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7
  Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty
  GemmTunableOp_float_NT,nt_25088_4096_64,Gemm_Hipblaslt_1219,1.262
  GemmTunableOp_float_NT,nt_4096_4096_64,Gemm_Rocblas_1216,0.033

  43. Note the "Validator" lines. If you change a library version, or ROCm version, or
  44. PyTorch version, TunableOp will detect this and reject the tunings file because
  45. the prior tunings are likely affected by other software changes.
  46. The remaining lines are the tuned solutions for each TunableOp encountered
  47. during your execution. Each line consists of 4 comma-separated fields: operator
  48. name, operator parameters, solution name, and average execution time. The
  49. execution time is an optional field. The CSV file can be edited, but with
  50. caution. For example, the solution name (field 3) can be changed to "Default"
  51. and it will fall back to the original PyTorch untuned implementation. Or, in the
  52. case of ROCm's hipBLAS or hipBLASLt libraries, if you know the specific solution
  53. index you can override the solution that TunableOp selected by replacing the
  54. value. The operator name and parameters (fields 1 and 2) are internally named
  55. and should not be modified. In the case of GemmTunableOp, field 1 indicates the
  56. datatype and whether the inputs are transposed (T) or not (N) and field 2
  57. indicates the M, N, K input shapes.
There is an option to enable verbose output but it is only recommended for
debugging purposes. This will produce a lot of diagnostic messages but may be
useful to see if TunableOp is being used at all. Otherwise, TunableOp is
completely silent, besides file output, unless there is a warning or error
during its use. The verbose option is only available by setting the environment
variable ``PYTORCH_TUNABLEOP_VERBOSE=1``.
  64. A Note on Tuning Behavior, Warmup, and Cache Effects
  65. ====================================================
Tuning an operator consists of iterating through the list of registered
implementations and profiling each one. The profile is established by running a
single implementation in a loop multiple times and taking the average execution
time. There is also an optional warmup phase prior to tuning that can help the
hardware reach stable power states. During tuning of a workload the various
hardware caches will more likely produce hits than when not tuning. There are
options for flushing the instruction cache and rotating the input tensors, which
might help produce a more faithful profile of the tuned operator, as if the
operator were run within a larger workload instead of in a tight, repetitive loop.
  75. By default, each possible solution for a given operator will be run for either
  76. 100 iterations or as many iterations that can be run within 30ms, whichever is
  77. smaller, and its average execution will be calculated. The fastest solution
  78. among all that were successfully profiled will be chosen. A profile might fail
  79. if the given solution doesn't achieve the same accuracy as the default
  80. implementation or if the solution returns an error code.
  81. Current Tunable Operators
  82. =========================
  83. TunableGemm for ROCm
  84. --------------------
  85. Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of
  86. PyTorch will function correctly when using TunableOp but the only solution
  87. available to CUDA builds is the 'Default' implementation i.e. the original
  88. cuBLAS default, now called through TunableOp. Any call to at::cuda::blas::gemm()
  89. or ::bgemm() will be routed through TunableOp when enabled. Calling gemm() for a
  90. given set of input arguments (transa, transb, m, n, k) will attempt to use the
  91. fastest available implementation across both rocblas and hipblaslt.
  92. Offline Tuning
  93. ==============
  94. Motivation
  95. ----------
  96. There are several use cases for offline tuning.
One use case involves a workload with high memory utilization, where regular (online) tuning might lead to running out of memory.
  98. Another use case is for compute-intensive workloads. In such cases, it is more resource-efficient to collect
  99. the GEMMs for the workload once and then tune repeatedly with different tuning parameters or libraries.
  100. Workflow
  101. --------
There are basically two steps:

1) Set the following environment variables to collect the untuned GEMMs; running the workload then generates ``tunableop_untuned0.csv``:

.. code-block:: bash

   export PYTORCH_TUNABLEOP_ENABLED=1
   export PYTORCH_TUNABLEOP_TUNING=0
   export PYTORCH_TUNABLEOP_RECORD_UNTUNED=1
   ...

2) Run a Python script that reads ``tunableop_untuned0.csv`` and generates ``tunableop_results0.csv``, like this:

.. code-block:: python

   import torch.cuda.tunable as tunable
   import os

   os.putenv("PYTORCH_TUNABLEOP_ENABLED", "1")
   os.putenv("PYTORCH_TUNABLEOP_TUNING", "1")
   os.putenv("PYTORCH_TUNABLEOP_RECORD_UNTUNED", "0")
   tunable.tune_gemm_in_file("tunableop_untuned0.csv")

It is also possible to take multiple untuned files and distribute the GEMMs for tuning to multiple GPUs
within a single node. In the first step, the GEMMs are gathered and duplicate GEMMs are eliminated.
Next, the GEMMs are distributed to different GPUs for tuning. After all GEMMs are tuned, the results from
all the GPUs are gathered into a single file whose base filename has ``_full0`` appended to it
(for example ``tunableop_results_full0.csv``). Finally, this new file, containing the gathered results, is
duplicated N times, once for each GPU, as a convenience for users who will run the workload with the tuned
configuration on N GPUs.

.. code-block:: python

   if __name__ == "__main__":
       num_gpus = 8  # number of GPUs that will be used during the tuning process
       tunable.mgpu_tune_gemm_in_file("tunableop_untuned?.csv", num_gpus)

Note that the usage of the ``mgpu_tune_gemm_in_file`` API is different from its single GPU counterpart
(``tune_gemm_in_file``). The body of the Python script that calls the API must be placed under an
``if __name__ == "__main__":`` guard as shown, because the ``concurrent.futures`` worker processes are
started with the ``spawn`` method. The argument to ``mgpu_tune_gemm_in_file`` must contain a wildcard
expression (``?`` or ``*``) to generate the list of untuned files containing the GEMMs to be processed.
The ``num_gpus`` argument must be between 1 and the total number of GPUs available.
  133. Tuning Context
  134. ==============
  135. The behavior of TunableOp is currently manipulated through environment
  136. variables, the C++ interface of at::cuda::tunable::getTuningContext(), or the
  137. torch.cuda.tunable python interfaces. The environment variables take precedence
  138. over any setting you manipulate using the C++ or Python APIs.
  139. Environment Variable Interface
  140. ------------------------------
  141. Environment variables are cached the first time they are read. You cannot use the
  142. environment variable interface programmatically since the settings become fixed.
  143. Use the C++ or Python APIs instead.
  144. """
  145. import concurrent.futures
  146. import glob
  147. import multiprocessing as mp
  148. import os
  149. import shutil
  150. import warnings
  151. from typing import Optional
  152. import torch
  153. __all__ = [
  154. "enable",
  155. "is_enabled",
  156. "tuning_enable",
  157. "tuning_is_enabled",
  158. "record_untuned_enable",
  159. "record_untuned_is_enabled",
  160. "set_max_tuning_duration",
  161. "get_max_tuning_duration",
  162. "set_max_tuning_iterations",
  163. "get_max_tuning_iterations",
  164. "set_filename",
  165. "get_filename",
  166. "get_results",
  167. "get_validators",
  168. "write_file_on_exit",
  169. "write_file",
  170. "read_file",
  171. "tune_gemm_in_file",
  172. "mgpu_tune_gemm_in_file",
  173. "set_rotating_buffer_size",
  174. "get_rotating_buffer_size",
  175. ]
  176. def enable(val: bool = True) -> None:
  177. r"""This is the big on/off switch for all TunableOp implementations."""
  178. torch._C._cuda_tunableop_enable(val) # type: ignore[attr-defined]
  179. def is_enabled() -> bool:
  180. r"""Returns whether the TunableOp feature is enabled."""
  181. return torch._C._cuda_tunableop_is_enabled() # type: ignore[attr-defined]
  182. def tuning_enable(val: bool = True) -> None:
  183. r"""Enable tuning of TunableOp implementations.
  184. When enabled, if a tuned entry isn't found, run the tuning step and record
  185. the entry.
  186. """
  187. torch._C._cuda_tunableop_tuning_enable(val) # type: ignore[attr-defined]
  188. def tuning_is_enabled() -> bool:
  189. r"""Returns whether TunableOp implementations can be tuned."""
  190. return torch._C._cuda_tunableop_tuning_is_enabled() # type: ignore[attr-defined]
  191. def record_untuned_enable(val: bool = True) -> None:
  192. r"""Enable recording untuned of TunableOp perations for offline tuning.
  193. When enabled, if a tuned entry isn't found, write it to the untuned file.
  194. """
  195. torch._C._cuda_record_untuned_enable(val) # type: ignore[attr-defined]
  196. def record_untuned_is_enabled() -> bool:
  197. r"""Returns whether TunableOp operations are recorded for offline tuning."""
  198. return torch._C._cuda_record_untuned_is_enabled() # type: ignore[attr-defined]
  199. def set_max_tuning_duration(duration: int) -> None:
  200. r"""Set max time in milliseconds to spend tuning a given solution.
  201. If both max tuning duration and iterations are set, the smaller of the two
  202. will be honored. At minimum 1 tuning iteration will always be run.
  203. """
  204. torch._C._cuda_tunableop_set_max_tuning_duration(duration) # type: ignore[attr-defined]
  205. def get_max_tuning_duration() -> int:
  206. r"""Get max time to spend tuning a given solution."""
  207. return torch._C._cuda_tunableop_get_max_tuning_duration() # type: ignore[attr-defined]
  208. def set_max_tuning_iterations(iterations: int) -> None:
  209. r"""Set max number of iterations to spend tuning a given solution.
  210. If both max tuning duration and iterations are set, the smaller of the two
  211. will be honored. At minimum 1 tuning iteration will always be run.
  212. """
  213. torch._C._cuda_tunableop_set_max_tuning_iterations(iterations) # type: ignore[attr-defined]
  214. def get_max_tuning_iterations() -> int:
  215. r"""Get max iterations to spend tuning a given solution."""
  216. return torch._C._cuda_tunableop_get_max_tuning_iterations() # type: ignore[attr-defined]
  217. def set_filename(filename: str, insert_device_ordinal: bool = False) -> None:
  218. r"""Set the filename to use for input/output of tuning results.
  219. If :attr:`insert_device_ordinal` is ``True`` then the current device ordinal
  220. will be added to the given filename automatically. This can be used in a
  221. 1-process-per-gpu scenario to ensure all processes write to a separate file.
  222. """
  223. torch._C._cuda_tunableop_set_filename(filename, insert_device_ordinal) # type: ignore[attr-defined]
  224. def get_filename() -> str:
  225. r"""Get the results filename."""
  226. return torch._C._cuda_tunableop_get_filename() # type: ignore[attr-defined]
  227. def get_results() -> tuple[str, str, str, float]:
  228. r"""Return all TunableOp results."""
  229. return torch._C._cuda_tunableop_get_results() # type: ignore[attr-defined]
  230. def get_validators() -> tuple[str, str]:
  231. r"""Return the TunableOp validators."""
  232. return torch._C._cuda_tunableop_get_validators() # type: ignore[attr-defined]
  233. def write_file_on_exit(val: bool) -> None:
  234. r"""During Tuning Context destruction, write file to disk.
  235. This is useful as a final flush of your results to disk if your application
  236. terminates as result of normal operation or an error. Manual flushing of
  237. your results can be achieved by manually calling ``write_file()``."""
  238. torch._C._cuda_tunableop_write_file_on_exit(val) # type: ignore[attr-defined]
  239. def write_file(filename: Optional[str] = None) -> bool:
  240. r"""Write results to a CSV file.
  241. If :attr:`filename` is not given, ``get_filename()`` is called.
  242. """
  243. if filename is None:
  244. filename = get_filename()
  245. return torch._C._cuda_tunableop_write_file(filename) # type: ignore[attr-defined]
  246. def read_file(filename: Optional[str] = None) -> bool:
  247. r"""Read results from a TunableOp CSV file.
  248. If :attr:`filename` is not given, ``get_filename()`` is called.
  249. """
  250. if filename is None:
  251. filename = get_filename()
  252. return torch._C._cuda_tunableop_read_file(filename) # type: ignore[attr-defined]
  253. def set_rotating_buffer_size(buffer_size: int) -> None:
  254. r"""Set rotating buffer size to this value in MB, if the buffer size is greater than zero.
  255. If less than zero, query L2 cache size. If equal to zero, means deactivate rotating buffer.
  256. """
  257. return torch._C._cuda_tunableop_set_rotating_buffer_size(buffer_size) # type: ignore[attr-defined]
  258. def get_rotating_buffer_size() -> int:
  259. r"""Get the rotating buffer size in kilobytes."""
  260. return torch._C._cuda_tunableop_get_rotating_buffer_size() # type: ignore[attr-defined]
  261. def tune_gemm_in_file(filename: str) -> None:
  262. r"""tune GEMM in file."""
  263. assert is_enabled()
  264. assert tuning_is_enabled()
  265. deviceid = torch.cuda.current_device()
  266. with open(filename) as file:
  267. for line in file:
  268. if line.startswith(("Gemm", "ScaledGemm")):
  269. _process_single_offline_gemm(line, deviceid)
  270. def _gather_unique_untuned_gemm_from_files(filename_pattern: str) -> set[str]:
  271. r"""Process multiple untuned results file and return a set with duplicates removed."""
  272. unique_gemm_entries = set() # set will avoid duplicates
  273. for file_path in glob.glob(filename_pattern):
  274. with open(file_path) as file:
  275. for line in file:
  276. if line.startswith(("Gemm", "ScaledGemm")):
  277. unique_gemm_entries.add(line)
  278. return unique_gemm_entries
  279. def _gather_tunableop_results() -> None:
  280. r"""Gather results from multiple tunableop results file and create a single file."""
  281. gemm_lines = set()
  282. validator_lines = []
  283. # Need to allow for the possibility that results filename was
  284. # set with the Python API instead of with environment variable.
  285. # Also possible that results filename was not set at all.
  286. # There are several test cases to check, but ultimately we
  287. # need a glob-able expression
  288. results_filename = get_filename() # Note empty string could be returned here
  289. if (
  290. results_filename is not None and results_filename != ""
  291. ): # Case were the Python API was used to set the filename
  292. dot_pos = results_filename.find(".")
  293. if dot_pos != -1 and dot_pos > 0:
  294. # Replace the character just to the left of the dot
  295. filename_pattern = (
  296. results_filename[: dot_pos - 1] + "?" + results_filename[dot_pos:]
  297. )
  298. else:
  299. filename_pattern = "" # Needed to make linter happy
  300. else: # Case where the environment variable was used to set the filename.
  301. results_filename_env = os.getenv("PYTORCH_TUNABLEOP_FILENAME")
  302. if results_filename_env is None or results_filename_env == "":
  303. filename_pattern = "tunableop_results?.csv"
  304. elif "%d" in results_filename_env:
  305. filename_pattern = results_filename_env.replace("%d", "?")
  306. else:
  307. filename_pattern = results_filename_env.replace(".", "?.")
  308. assert "?" in filename_pattern
  309. FirstFile = False
  310. matching_files = glob.glob(filename_pattern)
  311. num_matching_files = len(matching_files)
  312. for file_path in matching_files:
  313. with open(file_path) as file:
  314. for line in file:
  315. if line.startswith("Validator"):
  316. if not (FirstFile):
  317. # Only read Validator from first file
  318. validator_lines.append(line)
  319. else:
  320. gemm_lines.add(line)
  321. FirstFile = True
  322. output_file = filename_pattern.replace("?", "_full0")
  323. with open(output_file, "w") as out_file:
  324. for line in validator_lines:
  325. out_file.write(line)
  326. for line in gemm_lines:
  327. out_file.write(line)
  328. # Create num_matching_copies of the results file
  329. for i in range(1, num_matching_files):
  330. duplicate_file = output_file.replace("0", str(i))
  331. shutil.copy(output_file, duplicate_file)
  332. def _create_matrices(
  333. m: int,
  334. n: int,
  335. k: int,
  336. lda: int,
  337. ldb: int,
  338. ldc: int,
  339. transA: bool,
  340. transB: bool,
  341. dtypeA: torch.dtype,
  342. deviceid: str,
  343. dtypeB: Optional[torch.dtype] = None,
  344. randn: bool = True,
  345. subMatrix: bool = False,
  346. ) -> tuple[torch.Tensor, torch.Tensor]:
  347. r"""Helper function for _process_single_offline_gemm.
  348. Creates matrices that are then consumed by one of the Torch GEMM APIs.
  349. """
  350. # Fill parameters set for use with ScaledGEMM
  351. fillA = 0.25
  352. fillB = 0.75
  353. if dtypeB is None:
  354. dtypeB = dtypeA
  355. if subMatrix:
  356. # User reference for understanding leading dimension:
  357. # https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/dgemm.f
  358. # TO DO: According to lines 108 - 133, there is no lower bound on rowsA,
  359. # but there is a restriction on rowsB. Using this formula for now as it
  360. # seems to work for all UTs.
  361. rowsA = rowsB = max(ldc, k)
  362. if randn:
  363. matA = torch.randn(rowsA, lda, dtype=dtypeA, device=deviceid)
  364. matB = torch.randn(rowsB, ldb, dtype=dtypeA, device=deviceid)
  365. else:
  366. matA = torch.full((rowsA, lda), fillA, dtype=dtypeB, device=deviceid)
  367. matB = torch.full((rowsB, ldb), fillB, dtype=dtypeB, device=deviceid)
  368. subA = matA[:k, :m].t() if transA else matA[:m, :k]
  369. subB = matB[:n, :k].t() if transB else matB[:k, :n]
  370. return subA, subB
  371. else:
  372. if randn:
  373. matA = (
  374. torch.rand(k, m, dtype=dtypeA, device=deviceid).t()
  375. if transA
  376. else torch.rand(m, k, dtype=dtypeA, device=deviceid)
  377. )
  378. matB = (
  379. torch.rand(n, k, dtype=dtypeB, device=deviceid).t()
  380. if transB
  381. else torch.rand(k, n, dtype=dtypeB, device=deviceid)
  382. )
  383. else:
  384. matA = (
  385. torch.full((k, m), fillA, dtype=dtypeA, device=deviceid).t()
  386. if transA
  387. else torch.full((m, k), fillA, dtype=dtypeA, device=deviceid)
  388. )
  389. matB = (
  390. torch.full((n, k), fillB, dtype=dtypeB, device=deviceid).t()
  391. if transB
  392. else torch.full((k, n), fillB, dtype=dtypeB, device=deviceid)
  393. )
  394. return matA, matB
  395. def _create_batch_matrices(
  396. m: int,
  397. n: int,
  398. k: int,
  399. b: int,
  400. lda: int,
  401. ldb: int,
  402. ldc: int,
  403. transA: bool,
  404. transB: bool,
  405. dtype: torch.dtype,
  406. deviceid: str,
  407. subMatrix: bool = False,
  408. ) -> tuple[torch.Tensor, torch.Tensor]:
  409. r"""Helper function for _process_single_offline_gemm.
  410. Creates batch matrices that are then consumed by one of the Torch GEMM APIs.
  411. Similar to _create_matrices but for 3D batch matrices.
  412. """
  413. if subMatrix:
  414. # User reference for understanding leading dimension:
  415. # https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/dgemm.f
  416. # TO DO: According to lines 108 - 133, there is no lower bound on rowsA,
  417. # but there is a restriction on rowsB. Using this formula for now as it
  418. # seems to work for all UTs.
  419. rowsA = rowsB = max(ldc, k)
  420. matA = torch.randn(b, rowsA, lda, dtype=dtype, device=deviceid)
  421. matB = torch.randn(b, rowsB, ldb, dtype=dtype, device=deviceid)
  422. subA = matA[:b, :k, :m].transpose(1, 2) if transA else matA[:b, :m, :k]
  423. subB = matB[:b, :n, :k].transpose(1, 2) if transB else matB[:b, :k, :n]
  424. return subA, subB
  425. else:
  426. matA = (
  427. torch.rand(b, k, m, dtype=dtype, device=deviceid)
  428. if transA
  429. else torch.rand(b, m, k, dtype=dtype, device=deviceid)
  430. )
  431. matB = (
  432. torch.rand(b, n, k, dtype=dtype, device=deviceid)
  433. if transB
  434. else torch.rand(b, k, n, dtype=dtype, device=deviceid)
  435. )
  436. matA = matA.transpose(1, 2) if transA else matA
  437. matB = matB.transpose(1, 2) if transB else matB
  438. return matA, matB
def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
    r"""Process a single untuned GEMM.

    Parses one line of an untuned-GEMM file and creates operands with the
    recorded shapes, dtypes and layouts on ``cuda:<gpu_id>``. It then invokes
    the matching PyTorch op once:

    - ``torch.mm`` for ``GemmTunableOp``
    - ``torch.bmm`` for ``GemmStridedBatchedTunableOp``
    - ``torch._scaled_mm`` for ``ScaledGemmTunableOp``
    - ``torch.nn.functional.linear`` for ``GemmAndBiasTunableOp``

    The op's result is discarded. The call exists so that, with TunableOp and
    tuning enabled by the caller, the GEMM is tuned and recorded.

    Some shapes with m, n or k equal to 1 are skipped with a warning. An
    unrecognized op name also only produces a warning.

    NOTE(review): for non-scaled GEMMs this sets the process-wide flag
    ``torch.backends.cuda.matmul.allow_tf32`` and never restores it.

    Args:
        untuned_gemm_line: one CSV line. Field 1 encodes the op name, the
            dtype(s) and the transpose layout separated by ``_``. Field 2
            encodes the problem sizes and leading dimensions, plus batch,
            scaling and bias information for some ops.
        gpu_id: CUDA device ordinal on which the operands are created.

    Raises:
        TypeError: if the dtype required by the op cannot be resolved.
        AssertionError: if the line does not have the expected layout.
    """
    deviceid = "cuda:" + str(gpu_id)
    # Maps dtype names as they appear in the CSV to torch dtypes. "tf32" is
    # float32 storage; TF32 matmul is switched on separately below.
    dtype_dict = {
        "float": torch.float32,
        "tf32": torch.float32,
        "double": torch.float64,
        "BFloat16": torch.bfloat16,
        "Half": torch.half,
        "c10::complex<double>": torch.complex128,
        "c10::complex<float>": torch.complex64,
        "Float8_e4m3fn": torch.float8_e4m3fn,
        "Float8_e5m2": torch.float8_e5m2,
        "Float8_e4m3fnuz": torch.float8_e4m3fnuz,
        "Float8_e5m2fnuz": torch.float8_e5m2fnuz,
    }
    untuned_gemm = untuned_gemm_line.strip().split(",")[:]
    # Non-scaled ops have exactly two underscores in field 1; anything else is
    # treated as a ScaledGEMM below.
    underscore_count = untuned_gemm[0].count("_")
    # Initialize dtype to make linter happy
    dtype = None
    dtypeA = None
    dtypeB = None
    dtypeC = None
    # Extract BLAS parameters
    if underscore_count == 2:
        # Field 1 looks like "<op_sig>_<dtype>_<layout>",
        # e.g. "GemmTunableOp_float_NT".
        [op_sig, data_type, layout] = untuned_gemm[0].split("_")
        # The first layout letter is mapped to B and the second to A (sizes
        # and leading dimensions below are likewise read in n/m and ldb/lda
        # order).
        transB = layout[0] == "T"
        transA = layout[1] == "T"
        dtype = dtype_dict.get(data_type)  # None for an unknown dtype name
        if data_type == "tf32":
            # User must still set HIPBLASLT_ALLOW_TF32=1
            torch.backends.cuda.matmul.allow_tf32 = True
        else:
            torch.backends.cuda.matmul.allow_tf32 = False
    else:  # ScaledGEMM
        # Field 1 looks like
        # "<op_sig>_<A dtype>_<B dtype>_<C dtype>_<layout>", where the A and B
        # dtype names each contain one underscore (e.g. "Float8_e4m3fnuz") and
        # the C dtype name contains zero or one, hence 6 or 7 underscores.
        count = untuned_gemm[0].count("_")
        assert count in [6, 7]
        untuned_gemm_temp = untuned_gemm[0].split("_")
        # dtypeC = might not be FP8 type, keep track
        # of the the number of underscores
        op_sig = untuned_gemm_temp[0]
        data_typeA = untuned_gemm_temp[1] + "_" + untuned_gemm_temp[2]
        data_typeB = untuned_gemm_temp[3] + "_" + untuned_gemm_temp[4]
        if count == 7:
            data_typeC = untuned_gemm_temp[5] + "_" + untuned_gemm_temp[6]
        else:
            data_typeC = untuned_gemm_temp[5]
        # The layout token is the last one, at index ``count``.
        transB = untuned_gemm_temp[count][0] == "T"
        transA = untuned_gemm_temp[count][1] == "T"
        dtypeA = dtype_dict.get(data_typeA)
        dtypeB = dtype_dict.get(data_typeB)
        dtypeC = dtype_dict.get(data_typeC)
    # Field 2 tokens: [0] layout, [1:4] n, m, k, then "ld" followed by
    # ldb, lda, ldc. For strided-batched GEMMs the batch count sits at [5]
    # and "ld" at [6]. For ScaledGEMM, [8] == "rw" (row-wise flag at [9]) and
    # [10] == "bias" (bias dtype name or "None" at [11]).
    # Note: this rebinds ``untuned_gemm_temp`` from field 1 to field 2.
    untuned_gemm_temp = untuned_gemm[1].split("_")
    [n, m, k] = [int(g) for g in untuned_gemm_temp[1:4]]
    if op_sig == "GemmStridedBatchedTunableOp":
        assert untuned_gemm_temp[6] == "ld"
        [ldb, lda, ldc] = [int(g) for g in untuned_gemm_temp[7:10]]
    else:
        assert untuned_gemm_temp[4] == "ld"
        [ldb, lda, ldc] = [int(g) for g in untuned_gemm_temp[5:8]]
    # Detect subMatrix case: if every leading dimension equals one of the
    # problem sizes, dense operands are used; otherwise operands are slices
    # of larger buffers (see _create_matrices / _create_batch_matrices).
    if all(item in [n, m, k] for item in [lda, ldb, ldc]):
        subMatrix = False
    else:
        subMatrix = True
    if op_sig == "GemmTunableOp":
        # Warnings for unsupported cases:
        if m == 1 or n == 1 or k == 1:
            if (not transA) and (not transB):
                pass  # case is supported
            elif transA and n == 1:
                pass  # case is supported
            else:
                warnings.warn(
                    "Offline tuning is not supported for this GEMM. Use online tuning instead. "
                    + f"Skipped tuning for: {untuned_gemm[1]}"
                )
                return
        # Resolve linter issue
        if dtype is None or not isinstance(dtype, torch.dtype):
            raise TypeError(f"dtype must be a torch.dtype, but got {dtype}")
        matA, matB = _create_matrices(
            m, n, k, lda, ldb, ldc, transA, transB, dtype, deviceid, subMatrix=subMatrix
        )
        torch.mm(matA, matB)
    elif op_sig == "GemmStridedBatchedTunableOp":
        # Warnings for unsupported cases:
        if m == 1 or n == 1 or k == 1:
            warnings.warn(
                "Offline tuning is not support for this GEMM. Use online tuning instead. "
                + f"Skipped tuning for: {untuned_gemm[1]}"
            )
            return
        # Batch count.
        [b] = [int(g) for g in untuned_gemm_temp[5:6]]
        # Resolve linter issue
        if dtype is None or not isinstance(dtype, torch.dtype):
            raise TypeError(f"dtype must be a torch.dtype, but got {dtype}")
        matA, matB = _create_batch_matrices(
            m,
            n,
            k,
            b,
            lda,
            ldb,
            ldc,
            transA,
            transB,
            dtype,
            deviceid,
            subMatrix=subMatrix,
        )
        torch.bmm(matA, matB)
    elif op_sig == "ScaledGemmTunableOp":
        # Only combination supported by PyTorch
        assert transB is True
        assert transA is False
        # Resolve linter issue
        if dtypeA is None or not isinstance(dtypeA, torch.dtype):
            raise TypeError(f"dtype must be a torch.dtype, but got {dtypeA}")
        # Constant-filled operands (randn=False) are used for the scaled path.
        matA, matB = _create_matrices(
            m,
            n,
            k,
            lda,
            ldb,
            ldc,
            transA,
            transB,
            dtypeA,
            deviceid,
            dtypeB=dtypeB,
            randn=False,
            subMatrix=subMatrix,
        )
        assert untuned_gemm_temp[8] == "rw"
        if untuned_gemm_temp[9] == "1":
            rowwise = True
        else:
            rowwise = False
        if rowwise:
            # Row-wise scaling: one scale per row of A / column of B.
            scaleA = (
                torch.ones((1, m), device=deviceid)
                if transA
                else torch.ones((m, 1), device=deviceid)
            )
            scaleB = (
                torch.ones((1, n), device=deviceid)
                if transB
                else torch.ones((n, 1), device=deviceid)
            )
        else:
            # Tensor-wise scaling: a single scalar scale per operand.
            scaleA = torch.tensor(0.8, device=deviceid)
            scaleB = torch.tensor(0.9, device=deviceid)
        assert untuned_gemm_temp[10] == "bias"
        if untuned_gemm_temp[11] == "None":  # no bias vector
            torch._scaled_mm(
                matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=dtypeC
            )
        else:  # bias vector present
            fillbias = 0.10
            bias_dtype = dtype_dict.get(untuned_gemm_temp[11])
            bias = (
                torch.full((n,), fillbias, dtype=bias_dtype, device=deviceid)
                if transB
                else torch.full((m,), fillbias, dtype=bias_dtype, device=deviceid)
            )
            torch._scaled_mm(
                matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=dtypeC, bias=bias
            )
    elif op_sig == "GemmAndBiasTunableOp":
        # y = x*A^T + b
        assert transA != transB
        # Resolve linter issue
        if dtype is None or not isinstance(dtype, torch.dtype):
            raise TypeError(f"dtype must be a torch.dtype, but got {dtype}")
        bias = torch.rand(n, dtype=dtype, device=deviceid)
        X, matA = _create_matrices(
            m, n, k, lda, ldb, ldc, transA, transB, dtype, deviceid, subMatrix=subMatrix
        )
        # linear() multiplies by the transpose of its weight, so pass matA.t()
        # to compute X @ matA + bias.
        matA = matA.t()
        torch.nn.functional.linear(X, matA, bias)
    else:
        warnings.warn(f"error: unknown op {op_sig}")
  622. def _check_tuning_assertions() -> None:
  623. r"""Helper function for multi-GPU tuning case. Need to check that TunableOp feature
  624. is enabled and that tuning is enabled.
  625. """
  626. if is_enabled() is False:
  627. warnings.warn("TunableOp was disabled. Trying to enable now.")
  628. enable(True)
  629. assert is_enabled() is True
  630. assert tuning_is_enabled() is True
  631. assert record_untuned_is_enabled() is False
  632. def mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None:
  633. r"""Process one or more files and distribute work over one or more GPUs."""
  634. unique_gemm_entries = _gather_unique_untuned_gemm_from_files(filename_pattern)
  635. total_gpus = torch.cuda.device_count()
  636. assert 1 <= num_gpus <= total_gpus
  637. mp_context = mp.get_context("spawn")
  638. futures = [] # empty list to hold futures
  639. flush_results = [] # empty list to hold futures
  640. # GEMM are assigned to GPUs in a round robin manner
  641. h = 0
  642. with concurrent.futures.ProcessPoolExecutor(
  643. max_workers=num_gpus,
  644. mp_context=mp_context,
  645. initializer=_check_tuning_assertions,
  646. ) as executor:
  647. # The workers are a separate process. TunableOp will be
  648. # enabled in the child processes if PYTORCH_TUNABLEOP_ENABLED=1
  649. # In the initializer, we also try to enable TunableOP if th
  650. # environment variable was NOT set.
  651. for line in unique_gemm_entries:
  652. future = executor.submit(_process_single_offline_gemm, line, h)
  653. futures.append(future)
  654. h = (h + 1) % num_gpus
  655. for future in concurrent.futures.as_completed(futures):
  656. future.result()
  657. for g in range(num_gpus):
  658. flush_result = executor.submit(write_file)
  659. flush_results.append(flush_result)
  660. for flush_result in concurrent.futures.as_completed(flush_results):
  661. flush_result.result()
  662. torch.cuda.synchronize()
  663. _gather_tunableop_results()