load_cifar.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. import pickle as p
  2. import numpy as np
  3. from PIL import Image
  4. def load_CIFAR_batch(filename):
  5. """load single batch of cifar"""
  6. with open(filename, "rb") as f:
  7. datadict = p.load(f, encoding="bytes")
  8. # 以字典的形式取出数据
  9. X = datadict[b"data"]
  10. Y = datadict[b"fine_labels"]
  11. try:
  12. X = X.reshape(10000, 3, 32, 32)
  13. except:
  14. X = X.reshape(50000, 3, 32, 32)
  15. Y = np.array(Y)
  16. print(Y.shape)
  17. return X, Y
  18. if __name__ == "__main__":
  19. mode = "train"
  20. imgX, imgY = load_CIFAR_batch(f"./cifar-100-python/{mode}")
  21. with open(f"./cifar-100-python/{mode}_imgs/img_label.txt", "a+") as f:
  22. for i in range(imgY.shape[0]):
  23. f.write("img" + str(i) + " " + str(imgY[i]) + "\n")
  24. for i in range(imgX.shape[0]):
  25. imgs = imgX[i]
  26. img0 = imgs[0]
  27. img1 = imgs[1]
  28. img2 = imgs[2]
  29. i0 = Image.fromarray(img0)
  30. i1 = Image.fromarray(img1)
  31. i2 = Image.fromarray(img2)
  32. img = Image.merge("RGB", (i0, i1, i2))
  33. name = "img" + str(i) + ".png"
  34. img.save(f"./cifar-100-python/{mode}_imgs/" + name, "png")
  35. print("save successfully!")