device.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from contextlib import ContextDecorator
  16. from . import logging
  17. from .custom_device_list import (
  18. DCU_WHITELIST,
  19. GCU_WHITELIST,
  20. METAX_GPU_WHITELIST,
  21. MLU_WHITELIST,
  22. NPU_BLACKLIST,
  23. XPU_WHITELIST,
  24. )
  25. from .flags import DISABLE_DEV_MODEL_WL
  26. SUPPORTED_DEVICE_TYPE = [
  27. "cpu",
  28. "gpu",
  29. "xpu",
  30. "npu",
  31. "mlu",
  32. "gcu",
  33. "dcu",
  34. "iluvatar_gpu",
  35. "metax_gpu",
  36. ]
  37. def constr_device(device_type, device_ids):
  38. if device_type == "cpu" and device_ids is not None:
  39. raise ValueError("`device_ids` must be None for CPUs")
  40. if device_ids:
  41. device_ids = ",".join(map(str, device_ids))
  42. return f"{device_type}:{device_ids}"
  43. else:
  44. return f"{device_type}"
  45. def get_default_device():
  46. import paddle
  47. if paddle.device.is_compiled_with_cuda() and paddle.device.cuda.device_count() > 0:
  48. return constr_device("gpu", [0])
  49. else:
  50. return "cpu"
  51. def parse_device(device):
  52. """parse_device"""
  53. # According to https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/set_device_cn.html
  54. parts = device.split(":")
  55. if len(parts) > 2:
  56. raise ValueError(f"Invalid device: {device}")
  57. if len(parts) == 1:
  58. device_type, device_ids = parts[0], None
  59. else:
  60. device_type, device_ids = parts
  61. device_ids = device_ids.split(",")
  62. for device_id in device_ids:
  63. if not device_id.isdigit():
  64. raise ValueError(
  65. f"Device ID must be an integer. Invalid device ID: {device_id}"
  66. )
  67. device_ids = list(map(int, device_ids))
  68. device_type = device_type.lower()
  69. # raise_unsupported_device_error(device_type, SUPPORTED_DEVICE_TYPE)
  70. assert device_type.lower() in SUPPORTED_DEVICE_TYPE
  71. if device_type == "cpu" and device_ids is not None:
  72. raise ValueError("No Device ID should be specified for CPUs")
  73. return device_type, device_ids
  74. def update_device_num(device, num):
  75. device_type, device_ids = parse_device(device)
  76. if device_ids:
  77. assert len(device_ids) >= num
  78. return constr_device(device_type, device_ids[:num])
  79. else:
  80. return constr_device(device_type, device_ids)
  81. def set_env_for_device(device):
  82. device_type, _ = parse_device(device)
  83. return set_env_for_device_type(device_type)
  84. def set_env_for_device_type(device_type):
  85. import paddle
  86. def _set(envs):
  87. for key, val in envs.items():
  88. os.environ[key] = val
  89. logging.debug(f"{key} has been set to {val}.")
  90. # XXX: is_compiled_with_rocm() must be True on dcu platform ?
  91. if device_type.lower() == "dcu" and paddle.is_compiled_with_rocm():
  92. envs = {"FLAGS_conv_workspace_size_limit": "2000"}
  93. _set(envs)
  94. if device_type.lower() == "npu":
  95. envs = {
  96. "FLAGS_npu_jit_compile": "0",
  97. "FLAGS_use_stride_kernel": "0",
  98. "FLAGS_allocator_strategy": "auto_growth",
  99. "CUSTOM_DEVICE_BLACK_LIST": "pad3d,pad3d_grad,set_value,set_value_with_tensor",
  100. "FLAGS_npu_scale_aclnn": "True",
  101. "FLAGS_npu_split_aclnn": "True",
  102. }
  103. _set(envs)
  104. if device_type.lower() == "xpu":
  105. envs = {
  106. "BKCL_FORCE_SYNC": "1",
  107. "BKCL_TIMEOUT": "1800",
  108. "FLAGS_use_stride_kernel": "0",
  109. "XPU_BLACK_LIST": "pad3d",
  110. }
  111. _set(envs)
  112. if device_type.lower() == "metax_gpu":
  113. envs = {"FLAGS_use_stride_kernel": "0"}
  114. _set(envs)
  115. if device_type.lower() == "mlu":
  116. envs = {
  117. "FLAGS_use_stride_kernel": "0",
  118. "FLAGS_use_stream_safe_cuda_allocator": "0",
  119. }
  120. _set(envs)
  121. if device_type.lower() == "gcu":
  122. envs = {"FLAGS_use_stride_kernel": "0"}
  123. _set(envs)
  124. def check_supported_device_type(device_type, model_name):
  125. if DISABLE_DEV_MODEL_WL:
  126. logging.warning(
  127. "Skip checking if model is supported on device because the flag `PADDLE_PDX_DISABLE_DEV_MODEL_WL` has been set."
  128. )
  129. return
  130. tips = "You could set env `PADDLE_PDX_DISABLE_DEV_MODEL_WL` to `true` to disable this checking."
  131. if device_type == "dcu":
  132. assert model_name in DCU_WHITELIST, (
  133. f"The DCU device does not yet support `{model_name}` model!" + tips
  134. )
  135. elif device_type == "mlu":
  136. assert model_name in MLU_WHITELIST, (
  137. f"The MLU device does not yet support `{model_name}` model!" + tips
  138. )
  139. elif device_type == "metax_gpu":
  140. assert model_name in METAX_GPU_WHITELIST, (
  141. f"The METAX_GPU device does not yet support `{model_name}` model!" + tips
  142. )
  143. elif device_type == "npu":
  144. assert model_name not in NPU_BLACKLIST, (
  145. f"The NPU device does not yet support `{model_name}` model!" + tips
  146. )
  147. elif device_type == "xpu":
  148. assert model_name in XPU_WHITELIST, (
  149. f"The XPU device does not yet support `{model_name}` model!" + tips
  150. )
  151. elif device_type == "gcu":
  152. assert model_name in GCU_WHITELIST, (
  153. f"The GCU device does not yet support `{model_name}` model!" + tips
  154. )
  155. def check_supported_device(device, model_name):
  156. device_type, _ = parse_device(device)
  157. return check_supported_device_type(device_type, model_name)
  158. class TemporaryDeviceChanger(ContextDecorator):
  159. """
  160. A context manager to temporarily change global device
  161. """
  162. def __init__(self, new_device):
  163. # if new_device is None, nothing changed
  164. import paddle
  165. self.new_device = new_device
  166. self.original_device = paddle.device.get_device()
  167. def __enter__(self):
  168. import paddle
  169. if self.new_device is None:
  170. return self
  171. paddle.device.set_device(self.new_device)
  172. return self
  173. def __exit__(self, exc_type, exc_val, exc_tb):
  174. import paddle
  175. if self.new_device is None:
  176. return False
  177. paddle.device.set_device(self.original_device)
  178. return False