binding.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. import os
  2. import traceback
  3. from collections import defaultdict
  4. from collections.abc import Iterable, Iterator
  5. from contextlib import contextmanager
  6. from dataclasses import asdict, dataclass
  7. from enum import Enum
  8. from logging import getLogger
  9. from typing import Callable, Optional, TypeVar
  10. import torch
  11. from torch._utils_internal import signpost_event
# Names exported as the public API of this module.
__all__ = [
    "AffinityMode",
    "maybe_temporarily_apply_numa_binding_to_current_thread",
    "NumaOptions",
]

# Module-level logger, named after this module.
logger = getLogger(__name__)
class AffinityMode(str, Enum):
    """
    Strategies for choosing which logical CPUs a thread gets bound to.

    See behavior description for each affinity mode
    in torch.distributed.run.
    """

    # Members also subclass str, so each compares equal to its string value.
    NODE = "node"
    SOCKET = "socket"
    EXCLUSIVE = "exclusive"
    CORE_COMPLEX = "core-complex"
  27. @dataclass(frozen=True)
  28. class NumaOptions:
  29. affinity_mode: AffinityMode
  30. """
  31. If true, we will fall back to using the original command/entrypoint if we fail to compute
  32. or apply NUMA bindings.
  33. You should avoid using this option! It is only intended as a safety mechanism for facilitating
  34. mass rollouts of numa binding.
  35. """
  36. should_fall_back_if_binding_fails: bool = False
  37. @contextmanager
  38. def maybe_temporarily_apply_numa_binding_to_current_thread(
  39. *, gpu_index: int, numa_options: Optional[NumaOptions]
  40. ) -> Iterator[None]:
  41. """
  42. 1. Applies NUMA binding to the current thread, suitable for the thread
  43. which will be interacting with GPU gpu_index.
  44. 2. Resets to the original CPU affinity before exiting the context manager.
  45. """
  46. if numa_options is None:
  47. yield
  48. return
  49. original_logical_cpu_indices = _get_allowed_cpu_indices_for_current_thread()
  50. _apply_numa_binding_to_current_thread(
  51. gpu_index=gpu_index, numa_options=numa_options
  52. )
  53. yield
  54. _bind_current_thread_to_logical_cpus(
  55. logical_cpu_indices=original_logical_cpu_indices
  56. )
  57. def _apply_numa_binding_to_current_thread(
  58. *, gpu_index: int, numa_options: NumaOptions
  59. ) -> None:
  60. kwargs = {
  61. "gpu_index": gpu_index,
  62. "numa_options": asdict(numa_options),
  63. }
  64. logger.info("Attempting to apply NUMA binding, given input %r", kwargs)
  65. try:
  66. logical_cpu_indices = _get_logical_cpus_to_bind_to(
  67. gpu_index=gpu_index, numa_options=numa_options
  68. )
  69. logger.info(
  70. "Computed logical_cpu_indices=%s for NUMA binding",
  71. _get_ranges_str_from_ints(logical_cpu_indices),
  72. )
  73. _raise_if_logical_cpu_indices_invalid(logical_cpu_indices=logical_cpu_indices)
  74. logger.info(
  75. "Validated logical_cpu_indices=%s for NUMA binding",
  76. _get_ranges_str_from_ints(logical_cpu_indices),
  77. )
  78. _bind_current_thread_to_logical_cpus(logical_cpu_indices=logical_cpu_indices)
  79. logger.info(
  80. "Successfully bound to logical_cpu_indices=%s for NUMA binding",
  81. _get_ranges_str_from_ints(logical_cpu_indices),
  82. )
  83. signpost_event(
  84. category="numa_binding",
  85. name="apply_success",
  86. parameters={
  87. **kwargs,
  88. "logical_cpu_indices": _get_ranges_str_from_ints(logical_cpu_indices),
  89. },
  90. )
  91. except Exception:
  92. signpost_event(
  93. category="numa_binding",
  94. name="apply_exception",
  95. parameters={
  96. **kwargs,
  97. "traceback": traceback.format_exc(),
  98. },
  99. )
  100. logger.exception("Failed to apply NUMA binding for input=%r", kwargs)
  101. if numa_options.should_fall_back_if_binding_fails:
  102. logger.warning(
  103. "Continuing executing without applying NUMA binding, despite exception %s",
  104. traceback.format_exc(),
  105. )
  106. return None
  107. raise
  108. def _raise_if_logical_cpu_indices_invalid(*, logical_cpu_indices: set[int]) -> None:
  109. if not logical_cpu_indices:
  110. raise RuntimeError("Must bind to a non-empty set of CPU indices")
  111. def _bind_current_thread_to_logical_cpus(*, logical_cpu_indices: set[int]) -> None:
  112. # 0 represents the current thread
  113. os.sched_setaffinity(0, logical_cpu_indices)
  114. def _get_logical_cpus_to_bind_to(
  115. *,
  116. gpu_index: int,
  117. numa_options: NumaOptions,
  118. ) -> set[int]:
  119. """
  120. Args:
  121. gpu_index: The index of the GPU that will be used by the subprocess.
  122. Example: 0
  123. numa_options: See NumaOptions for details.
  124. Returns:
  125. Set of logical CPU indices to bind to.
  126. """
  127. if numa_options.affinity_mode == AffinityMode.NODE:
  128. logical_cpus = _node_get_logical_cpus_to_bind_to(gpu_index=gpu_index)
  129. elif numa_options.affinity_mode == AffinityMode.SOCKET:
  130. logical_cpus = _socket_get_logical_cpus_to_bind_to(gpu_index=gpu_index)
  131. elif numa_options.affinity_mode == AffinityMode.EXCLUSIVE:
  132. logical_cpus = _exclusive_get_logical_cpus_to_bind_to(gpu_index=gpu_index)
  133. elif numa_options.affinity_mode == AffinityMode.CORE_COMPLEX:
  134. logical_cpus = _core_complex_get_logical_cpus_to_bind_to(gpu_index=gpu_index)
  135. else:
  136. raise ValueError(f"Affinity mode {numa_options.affinity_mode} not supported.")
  137. return logical_cpus
  138. def _node_get_logical_cpus_to_bind_to(*, gpu_index: int) -> set[int]:
  139. """
  140. Core logic of 'node' numa strategy.
  141. """
  142. numa_node_index = _get_numa_node_index_for_gpu_index(gpu_index=gpu_index)
  143. return _get_allowed_logical_cpu_indices_for_numa_node(
  144. numa_node_index=numa_node_index
  145. )
  146. def _socket_get_logical_cpus_to_bind_to(*, gpu_index: int) -> set[int]:
  147. """
  148. Core logic of 'socket' numa strategy.
  149. """
  150. numa_node_index_of_gpu = _get_numa_node_index_for_gpu_index(gpu_index=gpu_index)
  151. socket_index = _get_socket_index_for_numa_node(
  152. numa_node_index=numa_node_index_of_gpu
  153. )
  154. numa_node_indices = _get_numa_node_indices_for_socket_index(
  155. socket_index=socket_index
  156. )
  157. logical_cpus = set()
  158. for numa_node_index in numa_node_indices:
  159. logical_cpus.update(
  160. _get_allowed_logical_cpu_indices_for_numa_node(
  161. numa_node_index=numa_node_index
  162. )
  163. )
  164. return logical_cpus
  165. def _exclusive_get_logical_cpus_to_bind_to(*, gpu_index: int) -> set[int]:
  166. """
  167. Core logic of 'exclusive' numa strategy.
  168. """
  169. numa_node_index = _get_numa_node_index_for_gpu_index(gpu_index=gpu_index)
  170. gpu_indices = _get_gpu_indices_for_numa_node(numa_node_index=numa_node_index)
  171. gpu_indices = sorted(gpu_indices)
  172. original_gpu_relative_index = gpu_indices.index(gpu_index)
  173. allowed_logical_cpu_indices = _get_allowed_logical_cpu_indices_for_numa_node(
  174. numa_node_index=numa_node_index
  175. )
  176. # Arbitrarily use the min logical cpu index on the physical core to
  177. # represent the physical core.
  178. physical_core_to_allowed_logical_cpu_indices = _group_by(
  179. allowed_logical_cpu_indices,
  180. lambda logical_cpu_index: min(
  181. _get_logical_cpu_indices_sharing_same_physical_core_as(
  182. logical_cpu_index=logical_cpu_index
  183. )
  184. ),
  185. )
  186. # Sort the dict for consistency (dicts maintain order in Python)
  187. physical_core_to_allowed_logical_cpu_indices = dict(
  188. sorted(physical_core_to_allowed_logical_cpu_indices.items())
  189. )
  190. num_physical_cores_per_gpu = len(
  191. physical_core_to_allowed_logical_cpu_indices
  192. ) // len(gpu_indices)
  193. # Often, the number of physical cores will not be perfectly divisible by the number
  194. # of GPUs. In those cases, give the lowest GPU indices an extra core
  195. num_gpus_to_give_one_extra_physical_core = len(
  196. physical_core_to_allowed_logical_cpu_indices
  197. ) % len(gpu_indices)
  198. if num_physical_cores_per_gpu < 1:
  199. raise RuntimeError(
  200. f"There are only {len(physical_core_to_allowed_logical_cpu_indices)} physical cores on {numa_node_index=},"
  201. + f" but there are {len(gpu_indices)} GPUs associated with this NUMA node."
  202. )
  203. # Compute slice indices for this GPU
  204. start = original_gpu_relative_index * num_physical_cores_per_gpu + min(
  205. original_gpu_relative_index, num_gpus_to_give_one_extra_physical_core
  206. )
  207. end = (
  208. start
  209. + num_physical_cores_per_gpu
  210. + (
  211. 1
  212. if original_gpu_relative_index < num_gpus_to_give_one_extra_physical_core
  213. else 0
  214. )
  215. )
  216. # Slice and flatten the logical CPUs from the selected physical cores
  217. logical_cpu_indices_for_original_gpu = {
  218. logical_cpu_index
  219. for logical_cpu_indices in list(
  220. physical_core_to_allowed_logical_cpu_indices.values()
  221. )[start:end]
  222. for logical_cpu_index in logical_cpu_indices
  223. }
  224. return logical_cpu_indices_for_original_gpu
  225. def _core_complex_get_logical_cpus_to_bind_to(*, gpu_index: int) -> set[int]:
  226. """
  227. Core logic of 'core-complex' numa strategy.
  228. Each GPU is assigned a full core complex (group of cores sharing L3 cache)
  229. within its affined NUMA node.
  230. """
  231. numa_node_index = _get_numa_node_index_for_gpu_index(gpu_index=gpu_index)
  232. gpu_indices = _get_gpu_indices_for_numa_node(numa_node_index=numa_node_index)
  233. gpu_indices = sorted(gpu_indices)
  234. original_gpu_relative_index = gpu_indices.index(gpu_index)
  235. allowed_logical_cpu_indices = _get_allowed_logical_cpu_indices_for_numa_node(
  236. numa_node_index=numa_node_index
  237. )
  238. # Arbitrarily use the min logical cpu index on the max level cache
  239. # to represent the max level cache.
  240. max_level_cache_to_allowed_logical_cpu_indices = _group_by(
  241. allowed_logical_cpu_indices,
  242. lambda logical_cpu_index: min(
  243. _get_logical_cpus_sharing_same_max_level_cache_as(
  244. logical_cpu_index=logical_cpu_index
  245. )
  246. ),
  247. )
  248. max_level_cache_to_allowed_logical_cpu_indices = dict(
  249. sorted(
  250. max_level_cache_to_allowed_logical_cpu_indices.items(),
  251. # First, prioritize caches with more available cpus
  252. # Second, prioritize lower index cpus (just for clarity/consistency)
  253. key=lambda item: (-len(item[1]), item[0]),
  254. )
  255. )
  256. cache_index_for_original_gpu = original_gpu_relative_index % len(
  257. max_level_cache_to_allowed_logical_cpu_indices
  258. )
  259. logical_cpu_indices_for_original_gpu = list(
  260. max_level_cache_to_allowed_logical_cpu_indices.values()
  261. )[cache_index_for_original_gpu]
  262. return logical_cpu_indices_for_original_gpu
  263. K = TypeVar("K")
  264. V = TypeVar("V")
  265. def _group_by(values: Iterable[V], get_key: Callable[[V], K]) -> dict[K, set[V]]:
  266. """
  267. Groups elements with same key into sets.
  268. """
  269. key_to_values: defaultdict[K, set[V]] = defaultdict(set)
  270. for value in values:
  271. key = get_key(value)
  272. key_to_values[key].add(value)
  273. return key_to_values
  274. def _get_logical_cpu_indices_sharing_same_physical_core_as(
  275. *, logical_cpu_index: int
  276. ) -> set[int]:
  277. thread_siblings_list_absolute_path = (
  278. f"/sys/devices/system/cpu/cpu{logical_cpu_index}/topology/thread_siblings_list"
  279. )
  280. with open(thread_siblings_list_absolute_path) as f:
  281. return _get_set_of_int_from_ranges_str(f.read())
  282. def _get_logical_cpus_sharing_same_max_level_cache_as(
  283. *, logical_cpu_index: int
  284. ) -> set[int]:
  285. cpu_cache_dir_absolute_path = (
  286. f"/sys/devices/system/cpu/cpu{logical_cpu_index}/cache"
  287. )
  288. max_level = -1
  289. logical_cpus_sharing_max_level_cache = set()
  290. for entry in os.listdir(cpu_cache_dir_absolute_path):
  291. if not entry.startswith("index") or not entry[5:].isdecimal():
  292. continue
  293. cache_index_absolute_path = os.path.join(cpu_cache_dir_absolute_path, entry)
  294. # Filter out other cache types like Instruction
  295. type_absolute_path = os.path.join(cache_index_absolute_path, "type")
  296. with open(type_absolute_path) as type_file:
  297. if type_file.read().strip() not in {"Unified", "Data"}:
  298. continue
  299. level_absolute_path = os.path.join(cache_index_absolute_path, "level")
  300. with open(level_absolute_path) as level_file:
  301. level = int(level_file.read())
  302. if level <= max_level:
  303. continue
  304. max_level = level
  305. shared_cpu_list_absolute_path = os.path.join(
  306. cache_index_absolute_path, "shared_cpu_list"
  307. )
  308. with open(shared_cpu_list_absolute_path) as share_cpu_list_file:
  309. logical_cpus_sharing_max_level_cache = _get_set_of_int_from_ranges_str(
  310. share_cpu_list_file.read()
  311. )
  312. return logical_cpus_sharing_max_level_cache
  313. def _get_allowed_logical_cpu_indices_for_numa_node(*, numa_node_index: int) -> set[int]:
  314. all_cpu_indices = _get_cpu_indices_for_numa_node_MAYBE_NOT_ALLOWED(
  315. numa_node_index=numa_node_index
  316. )
  317. allowed_cpu_indices = _get_allowed_cpu_indices_for_current_thread()
  318. return all_cpu_indices & allowed_cpu_indices
  319. def _get_cpu_indices_for_numa_node_MAYBE_NOT_ALLOWED(
  320. *, numa_node_index: int
  321. ) -> set[int]:
  322. """
  323. Returns:
  324. Indices of all CPUs associated with numa_node_index. However, the list
  325. is not filtered based on whether the thread is allowed to use them.
  326. """
  327. cpulist_absolute_path = f"/sys/devices/system/node/node{numa_node_index}/cpulist"
  328. try:
  329. with open(cpulist_absolute_path) as f:
  330. cpu_range_str = f.read()
  331. except FileNotFoundError as e:
  332. raise RuntimeError(
  333. f"Could not determine CPUs corresponding to {numa_node_index=}."
  334. ) from e
  335. return _get_set_of_int_from_ranges_str(cpu_range_str)
  336. def _get_gpu_count() -> int:
  337. return torch.cuda.device_count()
  338. def _get_numa_node_index_for_gpu_index(*, gpu_index: int) -> int:
  339. device_properties = torch.cuda.get_device_properties(gpu_index)
  340. domain = device_properties.pci_domain_id # type: ignore[attr-defined]
  341. bus = device_properties.pci_bus_id # type: ignore[attr-defined]
  342. device = device_properties.pci_device_id # type: ignore[attr-defined]
  343. # Format to sysfs PCI address: "0000:dc:00.0"
  344. pci_addr = f"{domain:04x}:{bus:02x}:{device:02x}.0"
  345. pci_numa_node_absolute_path = f"/sys/bus/pci/devices/{pci_addr}/numa_node"
  346. with open(pci_numa_node_absolute_path) as f:
  347. # In systems with only one NUMA node, this will
  348. # often be saved as -1. In those cases, there is obviously
  349. # at least one numa node, 0, so we use that.
  350. return max(int(f.read().strip()), 0)
  351. def _get_gpu_indices_for_numa_node(*, numa_node_index: int) -> set[int]:
  352. return {
  353. gpu_index
  354. for gpu_index in range(_get_gpu_count())
  355. if _get_numa_node_index_for_gpu_index(gpu_index=gpu_index) == numa_node_index
  356. }
  357. def _get_socket_index_for_numa_node(*, numa_node_index: int) -> int:
  358. arbitrary_cpu_index = _get_arbitrary_allowed_cpu_index_for_numa_node(
  359. numa_node_index=numa_node_index
  360. )
  361. return _get_socket_index_for_cpu(cpu_index=arbitrary_cpu_index)
  362. def _get_socket_index_for_cpu(*, cpu_index: int) -> int:
  363. package_id_absolute_path = (
  364. f"/sys/devices/system/cpu/cpu{cpu_index}/topology/physical_package_id"
  365. )
  366. try:
  367. with open(package_id_absolute_path) as f:
  368. return int(f.read().strip())
  369. except FileNotFoundError as e:
  370. raise RuntimeError(f"Could not determine socket for {cpu_index=}") from e
  371. def _get_arbitrary_allowed_cpu_index_for_numa_node(*, numa_node_index: int) -> int:
  372. return min(
  373. _get_allowed_logical_cpu_indices_for_numa_node(numa_node_index=numa_node_index)
  374. )
  375. def _get_set_of_int_from_ranges_str(ranges_str: str) -> set[int]:
  376. """
  377. Util for parsing a string of int ranges, as in a sysfs file.
  378. Args:
  379. ranges_str: E.g., "0-2,4,6-7"
  380. Returns:
  381. E.g., {0, 1, 2, 4, 6, 7}
  382. """
  383. ints: set[int] = set()
  384. for range_str in ranges_str.split(","):
  385. range_str = range_str.strip()
  386. if not range_str:
  387. continue
  388. if "-" in range_str:
  389. start_str, end_str = range_str.split("-")
  390. start, end = int(start_str), int(end_str)
  391. ints.update(range(start, end + 1))
  392. else:
  393. ints.add(int(range_str))
  394. return ints
  395. def _get_ranges_str_from_ints(ints: Iterable[int]) -> str:
  396. """
  397. Convert a set of integers to a compact string with ranges.
  398. Args:
  399. ints: E.g., {0, 1, 2, 4, 6, 7}
  400. Returns:
  401. E.g., "0-2,4,6-7"
  402. """
  403. if not ints:
  404. return ""
  405. sorted_ints = sorted(ints)
  406. ranges = []
  407. start = prev = sorted_ints[0]
  408. for num in sorted_ints[1:]:
  409. if num == prev + 1:
  410. prev = num
  411. else:
  412. if start == prev:
  413. ranges.append(f"{start}")
  414. else:
  415. ranges.append(f"{start}-{prev}")
  416. start = prev = num
  417. # Append the last range
  418. if start == prev:
  419. ranges.append(f"{start}")
  420. else:
  421. ranges.append(f"{start}-{prev}")
  422. return ",".join(ranges)
  423. def _get_systemwide_numa_node_indices() -> set[int]:
  424. with open("/sys/devices/system/node/possible") as f:
  425. possible_nodes_str = f.read()
  426. return _get_set_of_int_from_ranges_str(possible_nodes_str)
  427. def _get_numa_node_indices_for_socket_index(*, socket_index: int) -> set[int]:
  428. systemwide_numa_node_indices = _get_systemwide_numa_node_indices()
  429. matching_numa_node_indices = set()
  430. for numa_node_index in systemwide_numa_node_indices:
  431. arbitrary_cpu_index = _get_arbitrary_allowed_cpu_index_for_numa_node(
  432. numa_node_index=numa_node_index
  433. )
  434. if socket_index == _get_socket_index_for_cpu(cpu_index=arbitrary_cpu_index):
  435. matching_numa_node_indices.add(numa_node_index)
  436. return matching_numa_node_indices
  437. def _get_allowed_cpu_indices_for_current_thread() -> set[int]:
  438. # 0 denotes current thread
  439. return os.sched_getaffinity(0)