# -*- coding: utf-8 -*-
# compute_mean_std.py
# @Time : 2019/12/7 14:46
# @Author : zhoujun
import argparse
import os
import random

import cv2
import numpy as np
from tqdm import tqdm
  9. # calculate means and std
  10. train_txt_path = "./train_val_list.txt"
  11. CNum = 10000 # 挑选多少图片进行计算
  12. img_h, img_w = 640, 640
  13. imgs = np.zeros([img_w, img_h, 3, 1])
  14. means, stdevs = [], []
  15. with open(train_txt_path, "r") as f:
  16. lines = f.readlines()
  17. random.shuffle(lines) # shuffle , 随机挑选图片
  18. for i in tqdm(range(CNum)):
  19. img_path = lines[i].split("\t")[0]
  20. img = cv2.imread(img_path)
  21. img = cv2.resize(img, (img_h, img_w))
  22. img = img[:, :, :, np.newaxis]
  23. imgs = np.concatenate((imgs, img), axis=3)
  24. # print(i)
  25. imgs = imgs.astype(np.float32) / 255.0
  26. for i in tqdm(range(3)):
  27. pixels = imgs[:, :, i, :].ravel() # 拉成一行
  28. means.append(np.mean(pixels))
  29. stdevs.append(np.std(pixels))
  30. # cv2 读取的图像格式为BGR,PIL/Skimage读取到的都是RGB不用转
  31. means.reverse() # BGR --> RGB
  32. stdevs.reverse()
  33. print("normMean = {}".format(means))
  34. print("normStd = {}".format(stdevs))
  35. print("transforms.Normalize(normMean = {}, normStd = {})".format(means, stdevs))