# gen_best_ep.py

  1. import os
  2. from glob import glob
  3. import numpy as np
  4. from config import Config
  5. config = Config()
  6. eval_txts = sorted(glob('e_results/*_eval.txt'))
  7. print('eval_txts:', [_.split(os.sep)[-1] for _ in eval_txts])
  8. score_panel = {}
  9. sep = '&'
  10. metrics = ['sm', 'wfm', 'hce'] # we used HCE for DIS and wFm for others.
  11. if 'DIS5K' not in config.task:
  12. metrics.remove('hce')
  13. for metric in metrics:
  14. print('Metric:', metric)
  15. current_line_nums = []
  16. for idx_et, eval_txt in enumerate(eval_txts):
  17. with open(eval_txt, 'r') as f:
  18. lines = [l for l in f.readlines()[3:] if '.' in l]
  19. current_line_nums.append(len(lines))
  20. for idx_et, eval_txt in enumerate(eval_txts):
  21. with open(eval_txt, 'r') as f:
  22. lines = [l for l in f.readlines()[3:] if '.' in l]
  23. for idx_line, line in enumerate(lines[:min(current_line_nums)]): # Consist line numbers by the minimal result file.
  24. properties = line.strip().strip(sep).split(sep)
  25. dataset = properties[0].strip()
  26. ckpt = properties[1].strip()
  27. if int(ckpt.split('--epoch_')[-1].strip()) < 0:
  28. continue
  29. targe_idx = {
  30. 'sm': [5, 2, 2, 5, 5, 2],
  31. 'wfm': [3, 3, 8, 3, 3, 8],
  32. 'hce': [7, -1, -1, 7, 7, -1]
  33. }[metric][['DIS5K', 'COD', 'HRSOD', 'General', 'General-2K', 'Matting'].index(config.task)]
  34. if metric != 'hce':
  35. score_sm = float(properties[targe_idx].strip())
  36. else:
  37. score_sm = int(properties[targe_idx].strip().strip('.'))
  38. if idx_et == 0:
  39. score_panel[ckpt] = []
  40. score_panel[ckpt].append(score_sm)
  41. metrics_min = ['hce', 'mae']
  42. max_or_min = min if metric in metrics_min else max
  43. score_max = max_or_min(score_panel.values(), key=lambda x: np.sum(x))
  44. good_models = []
  45. for k, v in score_panel.items():
  46. if (np.sum(v) <= np.sum(score_max)) if metric in metrics_min else (np.sum(v) >= np.sum(score_max)):
  47. print(k, v)
  48. good_models.append(k)
  49. # Write
  50. with open(eval_txt, 'r') as f:
  51. lines = f.readlines()
  52. info4good_models = lines[:3]
  53. metric_names = [m.strip() for m in lines[1].strip().strip('&').split('&')[2:]]
  54. testset_mean_values = {metric_name: [] for metric_name in metric_names}
  55. for good_model in good_models:
  56. for idx_et, eval_txt in enumerate(eval_txts):
  57. with open(eval_txt, 'r') as f:
  58. lines = f.readlines()
  59. for line in lines:
  60. if set([good_model]) & set([_.strip() for _ in line.split(sep)]):
  61. info4good_models.append(line)
  62. metric_scores = [float(m.strip()) for m in line.strip().strip('&').split('&')[2:]]
  63. for idx_score, metric_score in enumerate(metric_scores):
  64. testset_mean_values[metric_names[idx_score]].append(metric_score)
  65. if 'DIS5K' in config.task:
  66. testset_mean_values_lst = ['{:<4}'.format(int(np.mean(v_lst[:-1]).round())) if name == 'HCE' else '{:.3f}'.format(np.mean(v_lst[:-1])).lstrip('0') for name, v_lst in testset_mean_values.items()] # [:-1] to remove DIS-VD
  67. sample_line_for_placing_mean_values = info4good_models[-2]
  68. numbers_placed_well = sample_line_for_placing_mean_values.replace(sample_line_for_placing_mean_values.split('&')[1].strip(), 'DIS-TEs').strip().split('&')[3:]
  69. for idx_number, (number_placed_well, testset_mean_value) in enumerate(zip(numbers_placed_well, testset_mean_values_lst)):
  70. numbers_placed_well[idx_number] = number_placed_well.replace(number_placed_well.strip(), testset_mean_value)
  71. testset_mean_line = '&'.join(sample_line_for_placing_mean_values.replace(sample_line_for_placing_mean_values.split('&')[1].strip(), 'DIS-TEs').split('&')[:3] + numbers_placed_well) + '\n'
  72. info4good_models.append(testset_mean_line)
  73. info4good_models.append(lines[-1])
  74. info = ''.join(info4good_models)
  75. print(info)
  76. with open(os.path.join('e_results', 'eval-{}_best_on_{}.txt'.format(config.task, metric)), 'w') as f:
  77. f.write(info + '\n')