examples.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. #!/usr/bin/env python
  2. # Copyright 2022 The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. A collection of utilities for comparing `examples/complete_*_example.py` scripts with the capabilities inside of each
  17. `examples/by_feature` example. `compare_against_test` is the main function that should be used when testing, while the
  18. others are used to either get the code that matters, or to preprocess them (such as stripping comments)
  19. """
  20. import os
  21. from typing import Optional
  22. def get_function_contents_by_name(lines: list[str], name: str):
  23. """
  24. Extracts a function from `lines` of segmented source code with the name `name`.
  25. Args:
  26. lines (`List[str]`):
  27. Source code of a script separated by line.
  28. name (`str`):
  29. The name of the function to extract. Should be either `training_function` or `main`
  30. """
  31. if name != "training_function" and name != "main":
  32. raise ValueError(f"Incorrect function name passed: {name}, choose either 'main' or 'training_function'")
  33. good_lines, found_start = [], False
  34. for line in lines:
  35. if not found_start and f"def {name}" in line:
  36. found_start = True
  37. good_lines.append(line)
  38. continue
  39. if found_start:
  40. if name == "training_function" and "def main" in line:
  41. return good_lines
  42. if name == "main" and "if __name__" in line:
  43. return good_lines
  44. good_lines.append(line)
  45. def clean_lines(lines: list[str]):
  46. """
  47. Filters `lines` and removes any entries that start with a comment ('#') or is just a newline ('\n')
  48. Args:
  49. lines (`List[str]`):
  50. Source code of a script separated by line.
  51. """
  52. return [line for line in lines if not line.lstrip().startswith("#") and line != "\n"]
  53. def compare_against_test(
  54. base_filename: str, feature_filename: str, parser_only: bool, secondary_filename: Optional[str] = None
  55. ):
  56. """
  57. Tests whether the additional code inside of `feature_filename` was implemented in `base_filename`. This should be
  58. used when testing to see if `complete_*_.py` examples have all of the implementations from each of the
  59. `examples/by_feature/*` scripts.
  60. It utilizes `nlp_example.py` to extract out all of the repeated training code, so that only the new additional code
  61. is examined and checked. If something *other* than `nlp_example.py` should be used, such as `cv_example.py` for the
  62. `complete_cv_example.py` script, it should be passed in for the `secondary_filename` parameter.
  63. Args:
  64. base_filename (`str` or `os.PathLike`):
  65. The filepath of a single "complete" example script to test, such as `examples/complete_cv_example.py`
  66. feature_filename (`str` or `os.PathLike`):
  67. The filepath of a single feature example script. The contents of this script are checked to see if they
  68. exist in `base_filename`
  69. parser_only (`bool`):
  70. Whether to compare only the `main()` sections in both files, or to compare the contents of
  71. `training_loop()`
  72. secondary_filename (`str`, *optional*):
  73. A potential secondary filepath that should be included in the check. This function extracts the base
  74. functionalities off of "examples/nlp_example.py", so if `base_filename` is a script other than
  75. `complete_nlp_example.py`, the template script should be included here. Such as `examples/cv_example.py`
  76. """
  77. with open(base_filename) as f:
  78. base_file_contents = f.readlines()
  79. with open(os.path.abspath(os.path.join("examples", "nlp_example.py"))) as f:
  80. full_file_contents = f.readlines()
  81. with open(feature_filename) as f:
  82. feature_file_contents = f.readlines()
  83. if secondary_filename is not None:
  84. with open(secondary_filename) as f:
  85. secondary_file_contents = f.readlines()
  86. # This is our base, we remove all the code from here in our `full_filename` and `feature_filename` to find the new content
  87. if parser_only:
  88. base_file_func = clean_lines(get_function_contents_by_name(base_file_contents, "main"))
  89. full_file_func = clean_lines(get_function_contents_by_name(full_file_contents, "main"))
  90. feature_file_func = clean_lines(get_function_contents_by_name(feature_file_contents, "main"))
  91. if secondary_filename is not None:
  92. secondary_file_func = clean_lines(get_function_contents_by_name(secondary_file_contents, "main"))
  93. else:
  94. base_file_func = clean_lines(get_function_contents_by_name(base_file_contents, "training_function"))
  95. full_file_func = clean_lines(get_function_contents_by_name(full_file_contents, "training_function"))
  96. feature_file_func = clean_lines(get_function_contents_by_name(feature_file_contents, "training_function"))
  97. if secondary_filename is not None:
  98. secondary_file_func = clean_lines(
  99. get_function_contents_by_name(secondary_file_contents, "training_function")
  100. )
  101. _dl_line = "train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n"
  102. # Specific code in our script that differs from the full version, aka what is new
  103. new_feature_code = []
  104. passed_idxs = [] # We keep track of the idxs just in case it's a repeated statement
  105. it = iter(feature_file_func)
  106. for i in range(len(feature_file_func) - 1):
  107. if i not in passed_idxs:
  108. line = next(it)
  109. if (line not in full_file_func) and (line.lstrip() != _dl_line):
  110. if "TESTING_MOCKED_DATALOADERS" not in line:
  111. new_feature_code.append(line)
  112. passed_idxs.append(i)
  113. else:
  114. # Skip over the `config['num_epochs'] = 2` statement
  115. _ = next(it)
  116. # Extract out just the new parts from the full_file_training_func
  117. new_full_example_parts = []
  118. passed_idxs = [] # We keep track of the idxs just in case it's a repeated statement
  119. for i, line in enumerate(base_file_func):
  120. if i not in passed_idxs:
  121. if (line not in full_file_func) and (line.lstrip() != _dl_line):
  122. if "TESTING_MOCKED_DATALOADERS" not in line:
  123. new_full_example_parts.append(line)
  124. passed_idxs.append(i)
  125. # Finally, get the overall diff
  126. diff_from_example = [line for line in new_feature_code if line not in new_full_example_parts]
  127. if secondary_filename is not None:
  128. diff_from_two = [line for line in full_file_contents if line not in secondary_file_func]
  129. diff_from_example = [line for line in diff_from_example if line not in diff_from_two]
  130. return diff_from_example