env.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from .deps import is_dep_available, require_deps
  15. def get_device_type():
  16. import paddle
  17. device_str = paddle.get_device()
  18. return device_str.split(":")[0]
  19. def get_paddle_version():
  20. import paddle
  21. version = paddle.__version__
  22. if "-" in version:
  23. version, tag = version.split("-")
  24. else:
  25. tag = None
  26. version = version.split(".")
  27. assert len(version) == 3
  28. major_v, minor_v, patch_v = map(int, version)
  29. if tag:
  30. return major_v, minor_v, patch_v, tag
  31. else:
  32. return major_v, minor_v, patch_v, None
  33. def get_paddle_cuda_version():
  34. import paddle.version
  35. cuda_version = paddle.version.cuda()
  36. if cuda_version == "False":
  37. return None
  38. return tuple(map(int, cuda_version.split(".")))
  39. def get_paddle_cudnn_version():
  40. import paddle.version
  41. cudnn_version = paddle.version.cudnn()
  42. if cudnn_version == "False":
  43. return None
  44. return tuple(map(int, cudnn_version.split(".")))
  45. # Should we also support getting the runtime versions of CUDA and cuDNN?
  46. def is_cuda_available():
  47. if is_dep_available("paddlepaddle"):
  48. import paddle.device
  49. # TODO: Check runtime availability
  50. return (
  51. paddle.device.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()
  52. )
  53. else:
  54. # If Paddle is unavailable, check GPU availability using PyTorch API.
  55. require_deps("torch")
  56. import torch.cuda
  57. import torch.version
  58. # Distinguish GPUs and DCUs by checking `torch.version.cuda`
  59. return torch.cuda.is_available() and torch.version.cuda
  60. def get_gpu_compute_capability():
  61. cap = None
  62. if is_cuda_available():
  63. if is_dep_available("paddlepaddle"):
  64. import paddle.device
  65. cap = paddle.device.cuda.get_device_capability()
  66. else:
  67. # If Paddle is unavailable, retrieve GPU compute capability from PyTorch instead.
  68. require_deps("torch")
  69. import torch.cuda
  70. cap = torch.cuda.get_device_capability()
  71. return cap