data_loader.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import numpy as np
  2. from paddle.vision.datasets import Cifar100
  3. from paddle.vision.transforms import Normalize
  4. import signal
  5. import os
  6. from paddle.io import Dataset, DataLoader, DistributedBatchSampler
  7. def term_mp(sig_num, frame):
  8. """kill all child processes"""
  9. pid = os.getpid()
  10. pgid = os.getpgid(os.getpid())
  11. print("main proc {} exit, kill process group " "{}".format(pid, pgid))
  12. os.killpg(pgid, signal.SIGKILL)
  13. return
  14. def build_dataloader(mode, batch_size=4, seed=None, num_workers=0, device="gpu:0"):
  15. normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format="HWC")
  16. if mode.lower() == "train":
  17. dataset = Cifar100(mode=mode, transform=normalize)
  18. elif mode.lower() in ["test", "valid", "eval"]:
  19. dataset = Cifar100(mode="test", transform=normalize)
  20. else:
  21. raise ValueError(f"{mode} should be one of ['train', 'test']")
  22. # define batch sampler
  23. batch_sampler = DistributedBatchSampler(
  24. dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True
  25. )
  26. data_loader = DataLoader(
  27. dataset=dataset,
  28. batch_sampler=batch_sampler,
  29. places=device,
  30. num_workers=num_workers,
  31. return_list=True,
  32. use_shared_memory=False,
  33. )
  34. # support exit using ctrl+c
  35. signal.signal(signal.SIGINT, term_mp)
  36. signal.signal(signal.SIGTERM, term_mp)
  37. return data_loader
# Usage example (uncomment to smoke-test the loader; `normalize` here is the
# Normalize transform that build_dataloader constructs internally):
# cifar100 = Cifar100(mode='train', transform=normalize)
# data = cifar100[0]
# image, label = data
# reader = build_dataloader('train')
# for idx, data in enumerate(reader):
#     print(idx, data[0].shape, data[1].shape)
#     if idx >= 10:
#         break