utils.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  class _StoreAction(argparse.Action):
      """
      Store action that also answers to the dashed spelling of underscored option names.

      For every option string whose name part (everything after the first two characters) contains `_`, an alias
      with each `_` replaced by `-` is inserted right after it, so `--foo_bar` is also reachable as `--foo-bar`.
      Whenever the action runs it stores `values` in `dest` and adds `dest` to the set `namespace.nondefault`,
      creating that set on first use.
      """

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)

          def _spellings(flag):
              # The spelling as declared always comes first, optionally followed by its dashed alias.
              yield flag
              # `[2:]` skips a two-character prefix such as `--`, so only underscores in the name count.
              if "_" in flag[2:]:
                  yield flag.replace("_", "-")

          self.option_strings = [spelling for flag in self.option_strings for spelling in _spellings(flag)]

      def __call__(self, parser, namespace, values, option_string=None):
          setattr(namespace, self.dest, values)
          # Keep a running record of every destination written through this action family.
          if hasattr(namespace, "nondefault"):
              written = namespace.nondefault
          else:
              written = namespace.nondefault = set()
          written.add(self.dest)
  class _StoreConstAction(_StoreAction):
      """
      Counterpart of `argparse._StoreConstAction` built on the local `_StoreAction`.

      The option takes no command-line value (`nargs=0`). When triggered it stores `const` in `dest` via
      `_StoreAction.__call__`, so the dashed-alias handling and the `namespace.nondefault` bookkeeping apply too.
      """

      def __init__(self, option_strings, dest, const, default=None, required=False, help=None):
          settings = dict(
              option_strings=option_strings,
              dest=dest,
              nargs=0,
              const=const,
              default=default,
              required=required,
              help=help,
          )
          super().__init__(**settings)

      def __call__(self, parser, namespace, values, option_string=None):
          # Whatever argparse hands over as `values` is ignored; the constant is stored instead.
          super().__call__(parser, namespace, self.const, option_string)
  class _StoreTrueAction(_StoreConstAction):
      """
      Counterpart of `argparse._StoreTrueAction` built on the local `_StoreConstAction`: a flag that stores `True`.

      Note that `default` is `None` unless the caller supplies one.
      """

      def __init__(self, option_strings, dest, default=None, required=False, help=None):
          super().__init__(
              option_strings=option_strings,
              dest=dest,
              const=True,
              default=default,
              required=required,
              help=help,
          )
  class CustomArgumentGroup(argparse._ArgumentGroup):
      """
      Argument group that rebuilds stock argparse store actions with the local `_StoreAction` family.

      The rebuilt actions also accept the `-` spelling of option names written with `_`, and they record the
      destinations they write to in `namespace.nondefault`. Any other kind of action is added unchanged.
      """

      def _add_action(self, action):
          """
          Replace `action` with its custom equivalent (if there is one) and register it with this group.

          Args:
              action: the `argparse.Action` created by `add_argument`.

          Returns:
              The action that was actually registered; this may be a new object rather than `action`.
          """
          args = vars(action)
          # `argparse._StoreTrueAction` subclasses `argparse._StoreConstAction`, so it must be tested first.
          # Other `_StoreConstAction` subclasses (e.g. the one behind `store_false`) fall into the second branch.
          if isinstance(action, argparse._StoreTrueAction):
              action = _StoreTrueAction(
                  args["option_strings"], args["dest"], args["default"], args["required"], args["help"]
              )
          elif isinstance(action, argparse._StoreConstAction):
              action = _StoreConstAction(
                  args["option_strings"],
                  args["dest"],
                  args["const"],
                  args["default"],
                  args["required"],
                  args["help"],
              )
          elif isinstance(action, argparse._StoreAction):
              action = _StoreAction(**args)
          # The const-style constructors above take no `deprecated` argument (added to `argparse.Action` in
          # Python 3.13). Carry the original setting over so `add_argument(..., deprecated=True)` is not lost.
          if "deprecated" in args:
              action.deprecated = args["deprecated"]
          return super()._add_action(action)
  class CustomArgumentParser(argparse.ArgumentParser):
      """
      Argument parser whose arguments accept either `-` or `_` in underscored option names.

      Arguments added without an explicit `action` use the local `_StoreAction`, and `action="store_true"` uses
      the local `_StoreTrueAction`. Any other explicit action is passed through untouched. Groups created through
      `add_argument_group` are `CustomArgumentGroup` instances.
      """

      def add_argument(self, *args, **kwargs):
          """
          Same as `argparse.ArgumentParser.add_argument`, but substitutes the custom store actions.

          Returns:
              The created `argparse.Action`, exactly as the stock `add_argument` does.
          """
          if "action" not in kwargs:
              # No action given: argparse would default to a plain store, so use the aliasing variant.
              kwargs["action"] = _StoreAction
          elif kwargs["action"] == "store_true":
              # Translate the action name into the custom class.
              kwargs["action"] = _StoreTrueAction
          # Previously the created action was dropped; return it so callers can keep a handle on it.
          return super().add_argument(*args, **kwargs)

      def add_argument_group(self, *args, **kwargs):
          """Create a `CustomArgumentGroup`, register it with this parser, and return it."""
          group = CustomArgumentGroup(self, *args, **kwargs)
          self._action_groups.append(group)
          return group