_cli_utils.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Contains a utility for good-looking prints."""
  15. import os
  16. from typing import List, Union
  17. class ANSI:
  18. """
  19. Helper for en.wikipedia.org/wiki/ANSI_escape_code
  20. """
  21. _bold = "\u001b[1m"
  22. _gray = "\u001b[90m"
  23. _red = "\u001b[31m"
  24. _reset = "\u001b[0m"
  25. _yellow = "\u001b[33m"
  26. @classmethod
  27. def bold(cls, s: str) -> str:
  28. return cls._format(s, cls._bold)
  29. @classmethod
  30. def gray(cls, s: str) -> str:
  31. return cls._format(s, cls._gray)
  32. @classmethod
  33. def red(cls, s: str) -> str:
  34. return cls._format(s, cls._bold + cls._red)
  35. @classmethod
  36. def yellow(cls, s: str) -> str:
  37. return cls._format(s, cls._yellow)
  38. @classmethod
  39. def _format(cls, s: str, code: str) -> str:
  40. if os.environ.get("NO_COLOR"):
  41. # See https://no-color.org/
  42. return s
  43. return f"{code}{s}{cls._reset}"
  44. def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
  45. """
  46. Inspired by:
  47. - stackoverflow.com/a/8356620/593036
  48. - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
  49. """
  50. col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
  51. row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
  52. lines = []
  53. lines.append(row_format.format(*headers))
  54. lines.append(row_format.format(*["-" * w for w in col_widths]))
  55. for row in rows:
  56. lines.append(row_format.format(*row))
  57. return "\n".join(lines)