versions.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib.metadata
  15. from typing import Union
  16. from packaging.version import Version, parse
  17. from .constants import STR_OPERATION_TO_FUNC
# Version of the installed `torch` distribution, read from package metadata and parsed once when this
# module is imported. Importing this module therefore requires torch to be installed; otherwise
# `importlib.metadata.PackageNotFoundError` is raised at import time.
torch_version = parse(importlib.metadata.version("torch"))
  19. def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
  20. """
  21. Compares a library version to some requirement using a given operation.
  22. Args:
  23. library_or_version (`str` or `packaging.version.Version`):
  24. A library name or a version to check.
  25. operation (`str`):
  26. A string representation of an operator, such as `">"` or `"<="`.
  27. requirement_version (`str`):
  28. The version to compare the library version against
  29. """
  30. if operation not in STR_OPERATION_TO_FUNC.keys():
  31. raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
  32. operation = STR_OPERATION_TO_FUNC[operation]
  33. if isinstance(library_or_version, str):
  34. library_or_version = parse(importlib.metadata.version(library_or_version))
  35. return operation(library_or_version, parse(requirement_version))
  36. def is_torch_version(operation: str, version: str):
  37. """
  38. Compares the current PyTorch version to a given reference with an operation.
  39. Args:
  40. operation (`str`):
  41. A string representation of an operator, such as `">"` or `"<="`
  42. version (`str`):
  43. A string version of PyTorch
  44. """
  45. return compare_versions(torch_version, operation, version)