operators.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. """
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. from __future__ import unicode_literals
  20. import sys
  21. import cv2
  22. import numpy as np
  23. import math
  24. from PIL import Image
  25. from paddle import get_device
  26. class DecodeImage(object):
  27. """decode image"""
  28. def __init__(
  29. self, img_mode="RGB", channel_first=False, ignore_orientation=False, **kwargs
  30. ):
  31. self.img_mode = img_mode
  32. self.channel_first = channel_first
  33. self.ignore_orientation = ignore_orientation
  34. def __call__(self, data):
  35. img = data["image"]
  36. assert type(img) is bytes and len(img) > 0, "invalid input 'img' in DecodeImage"
  37. img = np.frombuffer(img, dtype="uint8")
  38. if self.img_mode == "GRAY":
  39. # For GRAY mode, decode directly to a single-channel grayscale image.
  40. decode_flag = cv2.IMREAD_GRAYSCALE
  41. else:
  42. # For RGB mode, decode to a 3-channel color image.
  43. decode_flag = cv2.IMREAD_COLOR
  44. if self.ignore_orientation:
  45. decode_flag |= cv2.IMREAD_IGNORE_ORIENTATION
  46. img = cv2.imdecode(img, decode_flag)
  47. if img is None:
  48. return None
  49. if self.img_mode == "GRAY":
  50. img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  51. elif self.img_mode == "RGB":
  52. assert img.shape[2] == 3, "invalid shape of image[%s]" % (img.shape)
  53. img = img[:, :, ::-1]
  54. if self.channel_first:
  55. img = img.transpose((2, 0, 1))
  56. data["image"] = img
  57. return data
  58. class NormalizeImage(object):
  59. """normalize image such as subtract mean, divide std"""
  60. def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
  61. if isinstance(scale, str):
  62. scale = eval(scale)
  63. self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
  64. mean = mean if mean is not None else [0.485, 0.456, 0.406]
  65. std = std if std is not None else [0.229, 0.224, 0.225]
  66. shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
  67. self.mean = np.array(mean).reshape(shape).astype("float32")
  68. self.std = np.array(std).reshape(shape).astype("float32")
  69. def __call__(self, data):
  70. img = data["image"]
  71. from PIL import Image
  72. if isinstance(img, Image.Image):
  73. img = np.array(img)
  74. assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
  75. data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
  76. return data
  77. class ToCHWImage(object):
  78. """convert hwc image to chw image"""
  79. def __init__(self, **kwargs):
  80. pass
  81. def __call__(self, data):
  82. img = data["image"]
  83. from PIL import Image
  84. if isinstance(img, Image.Image):
  85. img = np.array(img)
  86. data["image"] = img.transpose((2, 0, 1))
  87. return data
  88. class Fasttext(object):
  89. def __init__(self, path="None", **kwargs):
  90. import fasttext
  91. self.fast_model = fasttext.load_model(path)
  92. def __call__(self, data):
  93. label = data["label"]
  94. fast_label = self.fast_model[label]
  95. data["fast_label"] = fast_label
  96. return data
  97. class KeepKeys(object):
  98. def __init__(self, keep_keys, **kwargs):
  99. self.keep_keys = keep_keys
  100. def __call__(self, data):
  101. data_list = []
  102. for key in self.keep_keys:
  103. data_list.append(data[key])
  104. return data_list
  105. class Pad(object):
  106. def __init__(self, size=None, size_div=32, **kwargs):
  107. if size is not None and not isinstance(size, (int, list, tuple)):
  108. raise TypeError(
  109. "Type of target_size is invalid. Now is {}".format(type(size))
  110. )
  111. if isinstance(size, int):
  112. size = [size, size]
  113. self.size = size
  114. self.size_div = size_div
  115. def __call__(self, data):
  116. img = data["image"]
  117. img_h, img_w = img.shape[0], img.shape[1]
  118. if self.size:
  119. resize_h2, resize_w2 = self.size
  120. assert (
  121. img_h < resize_h2 and img_w < resize_w2
  122. ), "(h, w) of target size should be greater than (img_h, img_w)"
  123. else:
  124. resize_h2 = max(
  125. int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
  126. self.size_div,
  127. )
  128. resize_w2 = max(
  129. int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
  130. self.size_div,
  131. )
  132. img = cv2.copyMakeBorder(
  133. img,
  134. 0,
  135. resize_h2 - img_h,
  136. 0,
  137. resize_w2 - img_w,
  138. cv2.BORDER_CONSTANT,
  139. value=0,
  140. )
  141. data["image"] = img
  142. return data
  143. class Resize(object):
  144. def __init__(self, size=(640, 640), **kwargs):
  145. self.size = size
  146. def resize_image(self, img):
  147. resize_h, resize_w = self.size
  148. ori_h, ori_w = img.shape[:2] # (h, w, c)
  149. ratio_h = float(resize_h) / ori_h
  150. ratio_w = float(resize_w) / ori_w
  151. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  152. return img, [ratio_h, ratio_w]
  153. def __call__(self, data):
  154. img = data["image"]
  155. if "polys" in data:
  156. text_polys = data["polys"]
  157. img_resize, [ratio_h, ratio_w] = self.resize_image(img)
  158. if "polys" in data:
  159. new_boxes = []
  160. for box in text_polys:
  161. new_box = []
  162. for cord in box:
  163. new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
  164. new_boxes.append(new_box)
  165. data["polys"] = np.array(new_boxes, dtype=np.float32)
  166. data["image"] = img_resize
  167. return data
  168. class DetResizeForTest(object):
  169. def __init__(self, **kwargs):
  170. super(DetResizeForTest, self).__init__()
  171. self.resize_type = 0
  172. self.keep_ratio = False
  173. if "image_shape" in kwargs:
  174. self.image_shape = kwargs["image_shape"]
  175. self.resize_type = 1
  176. if "keep_ratio" in kwargs:
  177. self.keep_ratio = kwargs["keep_ratio"]
  178. elif "limit_side_len" in kwargs:
  179. self.limit_side_len = kwargs["limit_side_len"]
  180. self.limit_type = kwargs.get("limit_type", "min")
  181. elif "resize_long" in kwargs:
  182. self.resize_type = 2
  183. self.resize_long = kwargs.get("resize_long", 960)
  184. else:
  185. self.limit_side_len = 736
  186. self.limit_type = "min"
  187. def __call__(self, data):
  188. img = data["image"]
  189. src_h, src_w, _ = img.shape
  190. if sum([src_h, src_w]) < 64:
  191. img = self.image_padding(img)
  192. if self.resize_type == 0:
  193. # img, shape = self.resize_image_type0(img)
  194. img, [ratio_h, ratio_w] = self.resize_image_type0(img)
  195. elif self.resize_type == 2:
  196. img, [ratio_h, ratio_w] = self.resize_image_type2(img)
  197. else:
  198. # img, shape = self.resize_image_type1(img)
  199. img, [ratio_h, ratio_w] = self.resize_image_type1(img)
  200. data["image"] = img
  201. data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
  202. if "iluvatar_gpu" in get_device():
  203. data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w]).astype(
  204. np.float32
  205. )
  206. return data
  207. def image_padding(self, im, value=0):
  208. h, w, c = im.shape
  209. im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
  210. im_pad[:h, :w, :] = im
  211. return im_pad
  212. def resize_image_type1(self, img):
  213. resize_h, resize_w = self.image_shape
  214. ori_h, ori_w = img.shape[:2] # (h, w, c)
  215. if self.keep_ratio is True:
  216. resize_w = ori_w * resize_h / ori_h
  217. N = math.ceil(resize_w / 32)
  218. resize_w = N * 32
  219. ratio_h = float(resize_h) / ori_h
  220. ratio_w = float(resize_w) / ori_w
  221. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  222. # return img, np.array([ori_h, ori_w])
  223. return img, [ratio_h, ratio_w]
  224. def resize_image_type0(self, img):
  225. """
  226. resize image to a size multiple of 32 which is required by the network
  227. args:
  228. img(array): array with shape [h, w, c]
  229. return(tuple):
  230. img, (ratio_h, ratio_w)
  231. """
  232. limit_side_len = self.limit_side_len
  233. h, w, c = img.shape
  234. # limit the max side
  235. if self.limit_type == "max":
  236. if max(h, w) > limit_side_len:
  237. if h > w:
  238. ratio = float(limit_side_len) / h
  239. else:
  240. ratio = float(limit_side_len) / w
  241. else:
  242. ratio = 1.0
  243. elif self.limit_type == "min":
  244. if min(h, w) < limit_side_len:
  245. if h < w:
  246. ratio = float(limit_side_len) / h
  247. else:
  248. ratio = float(limit_side_len) / w
  249. else:
  250. ratio = 1.0
  251. elif self.limit_type == "resize_long":
  252. ratio = float(limit_side_len) / max(h, w)
  253. else:
  254. raise Exception("not support limit type, image ")
  255. resize_h = int(h * ratio)
  256. resize_w = int(w * ratio)
  257. resize_h = max(int(round(resize_h / 32) * 32), 32)
  258. resize_w = max(int(round(resize_w / 32) * 32), 32)
  259. try:
  260. if int(resize_w) <= 0 or int(resize_h) <= 0:
  261. return None, (None, None)
  262. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  263. except:
  264. print(img.shape, resize_w, resize_h)
  265. sys.exit(0)
  266. ratio_h = resize_h / float(h)
  267. ratio_w = resize_w / float(w)
  268. return img, [ratio_h, ratio_w]
  269. def resize_image_type2(self, img):
  270. h, w, _ = img.shape
  271. resize_w = w
  272. resize_h = h
  273. if resize_h > resize_w:
  274. ratio = float(self.resize_long) / resize_h
  275. else:
  276. ratio = float(self.resize_long) / resize_w
  277. resize_h = int(resize_h * ratio)
  278. resize_w = int(resize_w * ratio)
  279. max_stride = 128
  280. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  281. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  282. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  283. ratio_h = resize_h / float(h)
  284. ratio_w = resize_w / float(w)
  285. return img, [ratio_h, ratio_w]
  286. class E2EResizeForTest(object):
  287. def __init__(self, **kwargs):
  288. super(E2EResizeForTest, self).__init__()
  289. self.max_side_len = kwargs["max_side_len"]
  290. self.valid_set = kwargs["valid_set"]
  291. def __call__(self, data):
  292. img = data["image"]
  293. src_h, src_w, _ = img.shape
  294. if self.valid_set == "totaltext":
  295. im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
  296. img, max_side_len=self.max_side_len
  297. )
  298. else:
  299. im_resized, (ratio_h, ratio_w) = self.resize_image(
  300. img, max_side_len=self.max_side_len
  301. )
  302. data["image"] = im_resized
  303. data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
  304. return data
  305. def resize_image_for_totaltext(self, im, max_side_len=512):
  306. h, w, _ = im.shape
  307. resize_w = w
  308. resize_h = h
  309. ratio = 1.25
  310. if h * ratio > max_side_len:
  311. ratio = float(max_side_len) / resize_h
  312. resize_h = int(resize_h * ratio)
  313. resize_w = int(resize_w * ratio)
  314. max_stride = 128
  315. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  316. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  317. im = cv2.resize(im, (int(resize_w), int(resize_h)))
  318. ratio_h = resize_h / float(h)
  319. ratio_w = resize_w / float(w)
  320. return im, (ratio_h, ratio_w)
  321. def resize_image(self, im, max_side_len=512):
  322. """
  323. resize image to a size multiple of max_stride which is required by the network
  324. :param im: the resized image
  325. :param max_side_len: limit of max image size to avoid out of memory in gpu
  326. :return: the resized image and the resize ratio
  327. """
  328. h, w, _ = im.shape
  329. resize_w = w
  330. resize_h = h
  331. # Fix the longer side
  332. if resize_h > resize_w:
  333. ratio = float(max_side_len) / resize_h
  334. else:
  335. ratio = float(max_side_len) / resize_w
  336. resize_h = int(resize_h * ratio)
  337. resize_w = int(resize_w * ratio)
  338. max_stride = 128
  339. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  340. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  341. im = cv2.resize(im, (int(resize_w), int(resize_h)))
  342. ratio_h = resize_h / float(h)
  343. ratio_w = resize_w / float(w)
  344. return im, (ratio_h, ratio_w)
  345. class KieResize(object):
  346. def __init__(self, **kwargs):
  347. super(KieResize, self).__init__()
  348. self.max_side, self.min_side = kwargs["img_scale"][0], kwargs["img_scale"][1]
  349. def __call__(self, data):
  350. img = data["image"]
  351. points = data["points"]
  352. src_h, src_w, _ = img.shape
  353. (
  354. im_resized,
  355. scale_factor,
  356. [ratio_h, ratio_w],
  357. [new_h, new_w],
  358. ) = self.resize_image(img)
  359. resize_points = self.resize_boxes(img, points, scale_factor)
  360. data["ori_image"] = img
  361. data["ori_boxes"] = points
  362. data["points"] = resize_points
  363. data["image"] = im_resized
  364. data["shape"] = np.array([new_h, new_w])
  365. return data
  366. def resize_image(self, img):
  367. norm_img = np.zeros([1024, 1024, 3], dtype="float32")
  368. scale = [512, 1024]
  369. h, w = img.shape[:2]
  370. max_long_edge = max(scale)
  371. max_short_edge = min(scale)
  372. scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
  373. resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(
  374. h * float(scale_factor) + 0.5
  375. )
  376. max_stride = 32
  377. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  378. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  379. im = cv2.resize(img, (resize_w, resize_h))
  380. new_h, new_w = im.shape[:2]
  381. w_scale = new_w / w
  382. h_scale = new_h / h
  383. scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
  384. norm_img[:new_h, :new_w, :] = im
  385. return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
  386. def resize_boxes(self, im, points, scale_factor):
  387. points = points * scale_factor
  388. img_shape = im.shape[:2]
  389. points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
  390. points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
  391. return points
  392. class SRResize(object):
  393. def __init__(
  394. self,
  395. imgH=32,
  396. imgW=128,
  397. down_sample_scale=4,
  398. keep_ratio=False,
  399. min_ratio=1,
  400. mask=False,
  401. infer_mode=False,
  402. **kwargs,
  403. ):
  404. self.imgH = imgH
  405. self.imgW = imgW
  406. self.keep_ratio = keep_ratio
  407. self.min_ratio = min_ratio
  408. self.down_sample_scale = down_sample_scale
  409. self.mask = mask
  410. self.infer_mode = infer_mode
  411. def __call__(self, data):
  412. imgH = self.imgH
  413. imgW = self.imgW
  414. images_lr = data["image_lr"]
  415. transform2 = ResizeNormalize(
  416. (imgW // self.down_sample_scale, imgH // self.down_sample_scale)
  417. )
  418. images_lr = transform2(images_lr)
  419. data["img_lr"] = images_lr
  420. if self.infer_mode:
  421. return data
  422. images_HR = data["image_hr"]
  423. label_strs = data["label"]
  424. transform = ResizeNormalize((imgW, imgH))
  425. images_HR = transform(images_HR)
  426. data["img_hr"] = images_HR
  427. return data
  428. class ResizeNormalize(object):
  429. def __init__(self, size, interpolation=Image.BICUBIC):
  430. self.size = size
  431. self.interpolation = interpolation
  432. def __call__(self, img):
  433. img = img.resize(self.size, self.interpolation)
  434. img_numpy = np.array(img).astype("float32")
  435. img_numpy = img_numpy.transpose((2, 0, 1)) / 255
  436. return img_numpy
  437. class GrayImageChannelFormat(object):
  438. """
  439. format gray scale image's channel: (3,h,w) -> (1,h,w)
  440. Args:
  441. inverse: inverse gray image
  442. """
  443. def __init__(self, inverse=False, **kwargs):
  444. self.inverse = inverse
  445. def __call__(self, data):
  446. img = data["image"]
  447. img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  448. img_expanded = np.expand_dims(img_single_channel, 0)
  449. if self.inverse:
  450. data["image"] = np.abs(img_expanded - 1)
  451. else:
  452. data["image"] = img_expanded
  453. data["src_image"] = img
  454. return data