process_grid.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright (c) 2022 Zhipu.AI
  2. import glob
  3. import os
  4. import statistics
  5. import sys
  6. import json
  7. path_pattern = sys.argv[1]
  8. target_type = sys.argv[2]
  9. best_value, best_result, best_name = None, None, None
  10. mean_result = {}
  11. print(path_pattern)
  12. for dir_path in glob.glob(path_pattern, recursive=True):
  13. entry = os.path.basename(dir_path)
  14. valid_result = None
  15. test_found = os.path.exists(os.path.join(dir_path, 'test_results.json'))
  16. valid_path = os.path.join(dir_path, 'results.json')
  17. if os.path.exists(valid_path):
  18. print(entry)
  19. with open(valid_path, encoding='utf-8') as file:
  20. valid_result = json.load(file)
  21. else:
  22. print(f'{entry} no validation results')
  23. continue
  24. if not test_found:
  25. print(f'{entry} not tested yet')
  26. if target_type == 'max':
  27. metric = sys.argv[3]
  28. metric_value = valid_result[metric]
  29. if best_value is None or metric_value > best_value:
  30. best_value = metric_value
  31. best_result = valid_result
  32. best_name = entry
  33. elif target_type == 'mean' or target_type == 'median':
  34. if mean_result:
  35. for metric, value in valid_result.items():
  36. if metric not in ['type', 'epoch']:
  37. mean_result[metric].append(value)
  38. else:
  39. mean_result = {
  40. metric: [value]
  41. for metric, value in valid_result.items()
  42. if metric not in ['type', 'epoch']
  43. }
  44. if target_type == 'max':
  45. print(f'Best result found at {best_name}: {best_result}')
  46. elif target_type == 'mean':
  47. mean_result = {
  48. metric: sum(value) / len(value)
  49. for metric, value in mean_result.items()
  50. }
  51. print(f'Mean result {mean_result}')
  52. elif target_type == 'median':
  53. mean_result = {
  54. metric: statistics.median(value)
  55. for metric, value in mean_result.items()
  56. }
  57. print(f'Mean result {mean_result}')