dataset.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2019/8/23 21:54
  3. # @Author : zhoujun
  4. import pathlib
  5. import os
  6. import cv2
  7. import numpy as np
  8. import scipy.io as sio
  9. from tqdm.auto import tqdm
  10. from base import BaseDataSet
  11. from utils import order_points_clockwise, get_datalist, load, expand_polygon
  12. class ICDAR2015Dataset(BaseDataSet):
  13. def __init__(
  14. self,
  15. data_path: str,
  16. img_mode,
  17. pre_processes,
  18. filter_keys,
  19. ignore_tags,
  20. transform=None,
  21. **kwargs,
  22. ):
  23. super().__init__(
  24. data_path, img_mode, pre_processes, filter_keys, ignore_tags, transform
  25. )
  26. def load_data(self, data_path: str) -> list:
  27. data_list = get_datalist(data_path)
  28. t_data_list = []
  29. for img_path, label_path in data_list:
  30. data = self._get_annotation(label_path)
  31. if len(data["text_polys"]) > 0:
  32. item = {"img_path": img_path, "img_name": pathlib.Path(img_path).stem}
  33. item.update(data)
  34. t_data_list.append(item)
  35. else:
  36. print("there is no suit bbox in {}".format(label_path))
  37. return t_data_list
  38. def _get_annotation(self, label_path: str) -> dict:
  39. boxes = []
  40. texts = []
  41. ignores = []
  42. with open(label_path, encoding="utf-8", mode="r") as f:
  43. for line in f.readlines():
  44. params = line.strip().strip("\ufeff").strip("\xef\xbb\xbf").split(",")
  45. try:
  46. box = order_points_clockwise(
  47. np.array(list(map(float, params[:8]))).reshape(-1, 2)
  48. )
  49. if cv2.contourArea(box) > 0:
  50. boxes.append(box)
  51. label = params[8]
  52. texts.append(label)
  53. ignores.append(label in self.ignore_tags)
  54. except:
  55. print("load label failed on {}".format(label_path))
  56. data = {
  57. "text_polys": np.array(boxes),
  58. "texts": texts,
  59. "ignore_tags": ignores,
  60. }
  61. return data
  62. class DetDataset(BaseDataSet):
  63. def __init__(
  64. self,
  65. data_path: str,
  66. img_mode,
  67. pre_processes,
  68. filter_keys,
  69. ignore_tags,
  70. transform=None,
  71. **kwargs,
  72. ):
  73. self.load_char_annotation = kwargs["load_char_annotation"]
  74. self.expand_one_char = kwargs["expand_one_char"]
  75. super().__init__(
  76. data_path, img_mode, pre_processes, filter_keys, ignore_tags, transform
  77. )
  78. def load_data(self, data_path: str) -> list:
  79. """
  80. 从json文件中读取出 文本行的坐标和gt,字符的坐标和gt
  81. :param data_path:
  82. :return:
  83. """
  84. data_list = []
  85. for path in data_path:
  86. content = load(path)
  87. for gt in tqdm(content["data_list"], desc="read file {}".format(path)):
  88. img_path = os.path.join(content["data_root"], gt["img_name"])
  89. polygons = []
  90. texts = []
  91. illegibility_list = []
  92. language_list = []
  93. for annotation in gt["annotations"]:
  94. if len(annotation["polygon"]) == 0 or len(annotation["text"]) == 0:
  95. continue
  96. if len(annotation["text"]) > 1 and self.expand_one_char:
  97. annotation["polygon"] = expand_polygon(annotation["polygon"])
  98. polygons.append(annotation["polygon"])
  99. texts.append(annotation["text"])
  100. illegibility_list.append(annotation["illegibility"])
  101. language_list.append(annotation["language"])
  102. if self.load_char_annotation:
  103. for char_annotation in annotation["chars"]:
  104. if (
  105. len(char_annotation["polygon"]) == 0
  106. or len(char_annotation["char"]) == 0
  107. ):
  108. continue
  109. polygons.append(char_annotation["polygon"])
  110. texts.append(char_annotation["char"])
  111. illegibility_list.append(char_annotation["illegibility"])
  112. language_list.append(char_annotation["language"])
  113. data_list.append(
  114. {
  115. "img_path": img_path,
  116. "img_name": gt["img_name"],
  117. "text_polys": np.array(polygons),
  118. "texts": texts,
  119. "ignore_tags": illegibility_list,
  120. }
  121. )
  122. return data_list
  123. class SynthTextDataset(BaseDataSet):
  124. def __init__(
  125. self,
  126. data_path: str,
  127. img_mode,
  128. pre_processes,
  129. filter_keys,
  130. transform=None,
  131. **kwargs,
  132. ):
  133. self.transform = transform
  134. self.dataRoot = pathlib.Path(data_path)
  135. if not self.dataRoot.exists():
  136. raise FileNotFoundError("Dataset folder is not exist.")
  137. self.targetFilePath = self.dataRoot / "gt.mat"
  138. if not self.targetFilePath.exists():
  139. raise FileExistsError("Target file is not exist.")
  140. targets = {}
  141. sio.loadmat(
  142. self.targetFilePath,
  143. targets,
  144. squeeze_me=True,
  145. struct_as_record=False,
  146. variable_names=["imnames", "wordBB", "txt"],
  147. )
  148. self.imageNames = targets["imnames"]
  149. self.wordBBoxes = targets["wordBB"]
  150. self.transcripts = targets["txt"]
  151. super().__init__(data_path, img_mode, pre_processes, filter_keys, transform)
  152. def load_data(self, data_path: str) -> list:
  153. t_data_list = []
  154. for imageName, wordBBoxes, texts in zip(
  155. self.imageNames, self.wordBBoxes, self.transcripts
  156. ):
  157. item = {}
  158. wordBBoxes = (
  159. np.expand_dims(wordBBoxes, axis=2)
  160. if (wordBBoxes.ndim == 2)
  161. else wordBBoxes
  162. )
  163. _, _, numOfWords = wordBBoxes.shape
  164. text_polys = wordBBoxes.reshape(
  165. [8, numOfWords], order="F"
  166. ).T # num_words * 8
  167. text_polys = text_polys.reshape(numOfWords, 4, 2) # num_of_words * 4 * 2
  168. transcripts = [word for line in texts for word in line.split()]
  169. if numOfWords != len(transcripts):
  170. continue
  171. item["img_path"] = str(self.dataRoot / imageName)
  172. item["img_name"] = (self.dataRoot / imageName).stem
  173. item["text_polys"] = text_polys
  174. item["texts"] = transcripts
  175. item["ignore_tags"] = [x in self.ignore_tags for x in transcripts]
  176. t_data_list.append(item)
  177. return t_data_list