deprecation.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. import warnings
  16. from functools import wraps
  17. from typing import Optional
  18. import packaging.version
  19. from .. import __version__
  20. from . import ExplicitEnum, is_torch_available, is_torchdynamo_compiling
  21. # This is needed in case we deprecate a kwarg of a function/method being compiled
  22. if is_torch_available():
  23. import torch # noqa: F401
  24. class Action(ExplicitEnum):
  25. NONE = "none"
  26. NOTIFY = "notify"
  27. NOTIFY_ALWAYS = "notify_always"
  28. RAISE = "raise"
  29. def deprecate_kwarg(
  30. old_name: str,
  31. version: str,
  32. new_name: Optional[str] = None,
  33. warn_if_greater_or_equal_version: bool = False,
  34. raise_if_greater_or_equal_version: bool = False,
  35. raise_if_both_names: bool = False,
  36. additional_message: Optional[str] = None,
  37. ):
  38. """
  39. Function or method decorator to notify users about deprecated keyword arguments, replacing them with a new name if specified.
  40. Note that is decorator is `torch.compile`-safe, i.e. it will not cause graph breaks (but no warning will be displayed if compiling).
  41. This decorator allows you to:
  42. - Notify users when a keyword argument is deprecated.
  43. - Automatically replace deprecated keyword arguments with new ones.
  44. - Raise an error if deprecated arguments are used, depending on the specified conditions.
  45. By default, the decorator notifies the user about the deprecated argument while the `transformers.__version__` < specified `version`
  46. in the decorator. To keep notifications with any version `warn_if_greater_or_equal_version=True` can be set.
  47. Parameters:
  48. old_name (`str`):
  49. Name of the deprecated keyword argument.
  50. version (`str`):
  51. The version in which the keyword argument was (or will be) deprecated.
  52. new_name (`Optional[str]`, *optional*):
  53. The new name for the deprecated keyword argument. If specified, the deprecated keyword argument will be replaced with this new name.
  54. warn_if_greater_or_equal_version (`bool`, *optional*, defaults to `False`):
  55. Whether to show warning if current `transformers` version is greater or equal to the deprecated version.
  56. raise_if_greater_or_equal_version (`bool`, *optional*, defaults to `False`):
  57. Whether to raise `ValueError` if current `transformers` version is greater or equal to the deprecated version.
  58. raise_if_both_names (`bool`, *optional*, defaults to `False`):
  59. Whether to raise `ValueError` if both deprecated and new keyword arguments are set.
  60. additional_message (`Optional[str]`, *optional*):
  61. An additional message to append to the default deprecation message.
  62. Raises:
  63. ValueError:
  64. If raise_if_greater_or_equal_version is True and the current version is greater than or equal to the deprecated version, or if raise_if_both_names is True and both old and new keyword arguments are provided.
  65. Returns:
  66. Callable:
  67. A wrapped function that handles the deprecated keyword arguments according to the specified parameters.
  68. Example usage with renaming argument:
  69. ```python
  70. @deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="6.0.0")
  71. def my_function(do_reduce_labels):
  72. print(do_reduce_labels)
  73. my_function(reduce_labels=True) # Will show a deprecation warning and use do_reduce_labels=True
  74. ```
  75. Example usage without renaming argument:
  76. ```python
  77. @deprecate_kwarg("max_size", version="6.0.0")
  78. def my_function(max_size):
  79. print(max_size)
  80. my_function(max_size=1333) # Will show a deprecation warning
  81. ```
  82. """
  83. deprecated_version = packaging.version.parse(version)
  84. current_version = packaging.version.parse(__version__)
  85. is_greater_or_equal_version = current_version >= deprecated_version
  86. if is_greater_or_equal_version:
  87. version_message = f"and removed starting from version {version}"
  88. else:
  89. version_message = f"and will be removed in version {version}"
  90. def wrapper(func):
  91. # Required for better warning message
  92. sig = inspect.signature(func)
  93. function_named_args = set(sig.parameters.keys())
  94. is_instance_method = "self" in function_named_args
  95. is_class_method = "cls" in function_named_args
  96. @wraps(func)
  97. def wrapped_func(*args, **kwargs):
  98. # Get class + function name (just for better warning message)
  99. func_name = func.__name__
  100. if is_instance_method:
  101. func_name = f"{args[0].__class__.__name__}.{func_name}"
  102. elif is_class_method:
  103. func_name = f"{args[0].__name__}.{func_name}"
  104. minimum_action = Action.NONE
  105. message = None
  106. # deprecated kwarg and its new version are set for function call -> replace it with new name
  107. if old_name in kwargs and new_name in kwargs:
  108. minimum_action = Action.RAISE if raise_if_both_names else Action.NOTIFY_ALWAYS
  109. message = f"Both `{old_name}` and `{new_name}` are set for `{func_name}`. Using `{new_name}={kwargs[new_name]}` and ignoring deprecated `{old_name}={kwargs[old_name]}`."
  110. kwargs.pop(old_name)
  111. # only deprecated kwarg is set for function call -> replace it with new name
  112. elif old_name in kwargs and new_name is not None and new_name not in kwargs:
  113. minimum_action = Action.NOTIFY
  114. message = f"`{old_name}` is deprecated {version_message} for `{func_name}`. Use `{new_name}` instead."
  115. kwargs[new_name] = kwargs.pop(old_name)
  116. # deprecated kwarg is not set for function call and new name is not specified -> just notify
  117. elif old_name in kwargs:
  118. minimum_action = Action.NOTIFY
  119. message = f"`{old_name}` is deprecated {version_message} for `{func_name}`."
  120. if message is not None and additional_message is not None:
  121. message = f"{message} {additional_message}"
  122. # update minimum_action if argument is ALREADY deprecated (current version >= deprecated version)
  123. if is_greater_or_equal_version:
  124. # change to (NOTIFY, NOTIFY_ALWAYS) -> RAISE if specified
  125. # in case we want to raise error for already deprecated arguments
  126. if raise_if_greater_or_equal_version and minimum_action != Action.NONE:
  127. minimum_action = Action.RAISE
  128. # change to NOTIFY -> NONE if specified (NOTIFY_ALWAYS can't be changed to NONE)
  129. # in case we want to ignore notifications for already deprecated arguments
  130. elif not warn_if_greater_or_equal_version and minimum_action == Action.NOTIFY:
  131. minimum_action = Action.NONE
  132. # raise error or notify user
  133. if minimum_action == Action.RAISE:
  134. raise ValueError(message)
  135. # If we are compiling, we do not raise the warning as it would break compilation
  136. elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS) and not is_torchdynamo_compiling():
  137. # DeprecationWarning is ignored by default, so we use FutureWarning instead
  138. warnings.warn(message, FutureWarning, stacklevel=2)
  139. return func(*args, **kwargs)
  140. return wrapped_func
  141. return wrapper