tvm.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
"""
This module provides TVM backend integration for TorchDynamo.

Apache TVM is a deep learning compiler framework that can optimize and execute
models on various hardware backends. This module enables:

- Compilation of PyTorch models to TVM's computation graphs
- Multiple scheduling options:
    - Default scheduler
    - Auto-scheduler for automatic optimization
    - Meta-schedule for evolutionary search-based tuning
- Hardware-specific optimizations:
    - CUDA GPU support
    - CPU support with LLVM targeting and architecture-specific tuning
    - Automatic detection of CPU capabilities (AVX2, AVX512)
- Tensor conversion utilities between PyTorch and TVM formats
- Configurable optimization levels and tuning trials

The backend can be used with torch.compile():

    model = torch.compile(model, backend="tvm")
"""
  19. import functools
  20. import importlib
  21. import logging
  22. import os
  23. import sys
  24. import tempfile
  25. from types import MappingProxyType
  26. from typing import Any, Callable, Optional
  27. import torch
  28. from torch import fx
  29. from .common import device_from_inputs, fake_tensor_unsupported
  30. from .registry import register_backend
  31. log = logging.getLogger(__name__)
  32. @register_backend
  33. @fake_tensor_unsupported # type: ignore[arg-type]
  34. def tvm(
  35. gm: fx.GraphModule,
  36. example_inputs: list[torch.Tensor],
  37. *,
  38. options: Optional[MappingProxyType[str, Any]] = None,
  39. ) -> Callable[..., Any]:
  40. if options is None:
  41. options = MappingProxyType({"scheduler": None, "trials": 20000, "opt_level": 3})
  42. assert options is not None
  43. import tvm # type: ignore[import]
  44. from tvm import relay # type: ignore[import]
  45. from tvm.contrib import graph_executor # type: ignore[import]
  46. jit_mod = torch.jit.trace(gm, example_inputs)
  47. device = device_from_inputs(example_inputs)
  48. shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
  49. example_outputs = gm(*example_inputs)
  50. if len(example_outputs) == 0:
  51. log.warning("Explicitly fall back to eager due to zero output")
  52. return gm.forward
  53. mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
  54. if device.type == "cuda":
  55. dev = tvm.cuda(device.index)
  56. target = tvm.target.cuda()
  57. else:
  58. dev = tvm.cpu(0)
  59. target = tvm.target.Target(llvm_target())
  60. scheduler = options.get("scheduler", None)
  61. if scheduler is None:
  62. scheduler = os.environ.get("TVM_SCHEDULER", None)
  63. trials = options.get("trials", 20000)
  64. opt_level = options.get("opt_level", 3)
  65. if scheduler == "auto_scheduler":
  66. from tvm import auto_scheduler
  67. log_file = tempfile.NamedTemporaryFile()
  68. if not os.path.exists(log_file):
  69. tasks, task_weights = auto_scheduler.extract_tasks(
  70. mod["main"], params, target
  71. )
  72. if len(tasks) != 0:
  73. tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
  74. if not os.path.exists(log_file):
  75. assert trials > 0
  76. tune_option = auto_scheduler.TuningOptions(
  77. num_measure_trials=trials,
  78. measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
  79. early_stopping=2000,
  80. )
  81. try:
  82. tuner.tune(tune_option)
  83. except Exception:
  84. if os.path.exists(log_file):
  85. os.unlink(log_file)
  86. raise
  87. with auto_scheduler.ApplyHistoryBest(log_file):
  88. with tvm.transform.PassContext(
  89. opt_level=opt_level, config={"relay.backend.use_auto_scheduler": True}
  90. ):
  91. lib = relay.build(mod, target=target, params=params)
  92. elif scheduler == "meta_schedule":
  93. from tvm import meta_schedule as ms
  94. with tempfile.TemporaryDirectory() as work_dir:
  95. if device.type != "cuda":
  96. # meta_schedule needs num-cores to be specified
  97. # here we use the maximum core count
  98. target = tvm.target.Target(
  99. f"{llvm_target()} --num-cores {ms.utils.cpu_count(logical=False)}"
  100. )
  101. # TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch
  102. # once USE_PT_TVMDSOOP is updated and turned on by default in TVM.
  103. assert trials > 0
  104. database = ms.relay_integration.tune_relay(
  105. mod=mod,
  106. target=target,
  107. work_dir=work_dir,
  108. max_trials_global=trials,
  109. num_trials_per_iter=64,
  110. params=params,
  111. strategy="evolutionary",
  112. opt_level=opt_level,
  113. )
  114. lib = ms.relay_integration.compile_relay(
  115. database=database,
  116. mod=mod,
  117. target=target,
  118. params=params,
  119. opt_level=opt_level,
  120. )
  121. elif scheduler == "default" or not scheduler:
  122. # no autotuning
  123. with tvm.transform.PassContext(opt_level=opt_level):
  124. lib = relay.build(mod, target=target, params=params)
  125. else:
  126. raise NotImplementedError(
  127. "This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. "
  128. "There are three available options: default, auto_scheduler and meta_schedule."
  129. )
  130. m = graph_executor.GraphModule(lib["default"](dev))
  131. def to_torch_tensor(nd_tensor: tvm.nd.array) -> torch.Tensor:
  132. """A helper function to transfer a NDArray to torch.tensor."""
  133. if nd_tensor.dtype == "bool":
  134. # DLPack does not support boolean so it can't be handled by
  135. # torch.utils.dlpack.from_pack. Workaround by going through
  136. # numpy, although this brings additional data copy overhead.
  137. return torch.from_numpy(nd_tensor.numpy())
  138. return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())
  139. def to_tvm_tensor(torch_tensor: torch.Tensor) -> tvm.nd.array:
  140. """A helper function to transfer a torch.tensor to NDArray."""
  141. if torch_tensor.dtype == torch.bool:
  142. # same reason as above, fallback to numpy conversion which
  143. # could introduce data copy overhead
  144. return tvm.nd.array(torch_tensor.cpu().numpy())
  145. return tvm.nd.from_dlpack(torch_tensor)
  146. def exec_tvm(*i_args: torch.Tensor) -> list[torch.Tensor]:
  147. args = [a.contiguous() for a in i_args]
  148. shape_info, _ = m.get_input_info()
  149. active_inputs = {name for name, _ in shape_info.items()}
  150. for idx, arg in enumerate(args, 0):
  151. if arg.dim() != 0:
  152. if arg.requires_grad:
  153. arg = arg.detach()
  154. inp_name = f"inp_{idx}"
  155. if inp_name not in active_inputs:
  156. log.warning(
  157. "input %s skipped as not found in tvm's runtime library",
  158. inp_name,
  159. )
  160. continue
  161. m.set_input(
  162. inp_name,
  163. to_tvm_tensor(arg),
  164. )
  165. m.run()
  166. return [to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())]
  167. return exec_tvm
  168. tvm_meta_schedule = functools.partial(tvm, scheduler="meta_schedule")
  169. tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler")
  170. def has_tvm() -> bool:
  171. try:
  172. importlib.import_module("tvm")
  173. return True
  174. except ImportError:
  175. return False
  176. @functools.cache
  177. def llvm_target() -> str:
  178. if sys.platform == "linux":
  179. cpuinfo = open("/proc/cpuinfo").read()
  180. if "avx512" in cpuinfo:
  181. return "llvm -mcpu=skylake-avx512"
  182. elif "avx2" in cpuinfo:
  183. return "llvm -mcpu=core-avx2"
  184. return "llvm"