io_utils.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import os
  2. import os.path as osp
  3. import glob
  4. from pathlib import Path
  5. import cv2
  6. import numpy as np
  7. import json
  8. IMG_EXT = ['.bmp', '.jpg', '.png', '.jpeg']
  9. # 兼容 numpy 2.x(移除了 bool8, float_, int_)
  10. try:
  11. NP_BOOL_TYPES = (np.bool_, np.bool8)
  12. except AttributeError:
  13. # numpy 2.x 移除了 bool8,只使用 bool_
  14. NP_BOOL_TYPES = (np.bool_,)
  15. try:
  16. NP_FLOAT_TYPES = (np.float_, np.float16, np.float32, np.float64)
  17. except AttributeError:
  18. # numpy 2.x 移除了 float_,使用 float64 代替
  19. NP_FLOAT_TYPES = (np.float64, np.float16, np.float32, np.float64)
  20. try:
  21. NP_INT_TYPES = (np.int_, np.int8, np.int16, np.int32, np.int64, np.uint, np.uint8, np.uint16, np.uint32, np.uint64)
  22. except AttributeError:
  23. # numpy 2.x 移除了 int_ 和 uint,使用 int64 和 uint64 代替
  24. NP_INT_TYPES = (np.int64, np.int8, np.int16, np.int32, np.int64, np.uint64, np.uint8, np.uint16, np.uint32, np.uint64)
  25. # https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
  26. class NumpyEncoder(json.JSONEncoder):
  27. def default(self, obj):
  28. if isinstance(obj, np.ndarray):
  29. return obj.tolist()
  30. elif isinstance(obj, np.ScalarType):
  31. if isinstance(obj, NP_BOOL_TYPES):
  32. return bool(obj)
  33. elif isinstance(obj, NP_FLOAT_TYPES):
  34. return float(obj)
  35. elif isinstance(obj, NP_INT_TYPES):
  36. return int(obj)
  37. return json.JSONEncoder.default(self, obj)
  38. def find_all_imgs(img_dir, abs_path=False):
  39. imglist = list()
  40. for filep in glob.glob(osp.join(img_dir, "*")):
  41. filename = osp.basename(filep)
  42. file_suffix = Path(filename).suffix
  43. if file_suffix.lower() not in IMG_EXT:
  44. continue
  45. if abs_path:
  46. imglist.append(filep)
  47. else:
  48. imglist.append(filename)
  49. return imglist
  50. imread = lambda imgpath, read_type=cv2.IMREAD_COLOR: cv2.imdecode(np.fromfile(imgpath, dtype=np.uint8), read_type)
  51. # def imread(imgpath, read_type=cv2.IMREAD_COLOR):
  52. # img = cv2.imdecode(np.fromfile(imgpath, dtype=np.uint8), read_type)
  53. # return img
  54. def imwrite(img_path, img, ext='.png'):
  55. suffix = Path(img_path).suffix
  56. if suffix != '':
  57. img_path = img_path.replace(suffix, ext)
  58. else:
  59. img_path += ext
  60. cv2.imencode(ext, img)[1].tofile(img_path)