script.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from collections import namedtuple
  4. from . import rrc_evaluation_funcs
  5. import Polygon as plg
  6. import numpy as np
  7. def default_evaluation_params():
  8. """
  9. default_evaluation_params: Default parameters to use for the validation and evaluation.
  10. """
  11. return {
  12. "IOU_CONSTRAINT": 0.5,
  13. "AREA_PRECISION_CONSTRAINT": 0.5,
  14. "GT_SAMPLE_NAME_2_ID": "gt_img_([0-9]+).txt",
  15. "DET_SAMPLE_NAME_2_ID": "res_img_([0-9]+).txt",
  16. "LTRB": False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4)
  17. "CRLF": False, # Lines are delimited by Windows CRLF format
  18. "CONFIDENCES": False, # Detections must include confidence value. AP will be calculated
  19. "PER_SAMPLE_RESULTS": True, # Generate per sample results and produce data for visualization
  20. }
  21. def validate_data(gtFilePath, submFilePath, evaluationParams):
  22. """
  23. Method validate_data: validates that all files in the results folder are correct (have the correct name contents).
  24. Validates also that there are no missing files in the folder.
  25. If some error detected, the method raises the error
  26. """
  27. gt = rrc_evaluation_funcs.load_folder_file(
  28. gtFilePath, evaluationParams["GT_SAMPLE_NAME_2_ID"]
  29. )
  30. subm = rrc_evaluation_funcs.load_folder_file(
  31. submFilePath, evaluationParams["DET_SAMPLE_NAME_2_ID"], True
  32. )
  33. # Validate format of GroundTruth
  34. for k in gt:
  35. rrc_evaluation_funcs.validate_lines_in_file(
  36. k, gt[k], evaluationParams["CRLF"], evaluationParams["LTRB"], True
  37. )
  38. # Validate format of results
  39. for k in subm:
  40. if (k in gt) == False:
  41. raise Exception("The sample %s not present in GT" % k)
  42. rrc_evaluation_funcs.validate_lines_in_file(
  43. k,
  44. subm[k],
  45. evaluationParams["CRLF"],
  46. evaluationParams["LTRB"],
  47. False,
  48. evaluationParams["CONFIDENCES"],
  49. )
  50. def evaluate_method(gtFilePath, submFilePath, evaluationParams):
  51. """
  52. Method evaluate_method: evaluate method and returns the results
  53. Results. Dictionary with the following values:
  54. - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
  55. - samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 }
  56. """
  57. def polygon_from_points(points):
  58. """
  59. Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
  60. """
  61. resBoxes = np.empty([1, 8], dtype="int32")
  62. resBoxes[0, 0] = int(points[0])
  63. resBoxes[0, 4] = int(points[1])
  64. resBoxes[0, 1] = int(points[2])
  65. resBoxes[0, 5] = int(points[3])
  66. resBoxes[0, 2] = int(points[4])
  67. resBoxes[0, 6] = int(points[5])
  68. resBoxes[0, 3] = int(points[6])
  69. resBoxes[0, 7] = int(points[7])
  70. pointMat = resBoxes[0].reshape([2, 4]).T
  71. return plg.Polygon(pointMat)
  72. def rectangle_to_polygon(rect):
  73. resBoxes = np.empty([1, 8], dtype="int32")
  74. resBoxes[0, 0] = int(rect.xmin)
  75. resBoxes[0, 4] = int(rect.ymax)
  76. resBoxes[0, 1] = int(rect.xmin)
  77. resBoxes[0, 5] = int(rect.ymin)
  78. resBoxes[0, 2] = int(rect.xmax)
  79. resBoxes[0, 6] = int(rect.ymin)
  80. resBoxes[0, 3] = int(rect.xmax)
  81. resBoxes[0, 7] = int(rect.ymax)
  82. pointMat = resBoxes[0].reshape([2, 4]).T
  83. return plg.Polygon(pointMat)
  84. def rectangle_to_points(rect):
  85. points = [
  86. int(rect.xmin),
  87. int(rect.ymax),
  88. int(rect.xmax),
  89. int(rect.ymax),
  90. int(rect.xmax),
  91. int(rect.ymin),
  92. int(rect.xmin),
  93. int(rect.ymin),
  94. ]
  95. return points
  96. def get_union(pD, pG):
  97. areaA = pD.area()
  98. areaB = pG.area()
  99. return areaA + areaB - get_intersection(pD, pG)
  100. def get_intersection_over_union(pD, pG):
  101. try:
  102. return get_intersection(pD, pG) / get_union(pD, pG)
  103. except:
  104. return 0
  105. def get_intersection(pD, pG):
  106. pInt = pD & pG
  107. if len(pInt) == 0:
  108. return 0
  109. return pInt.area()
  110. def compute_ap(confList, matchList, numGtCare):
  111. correct = 0
  112. AP = 0
  113. if len(confList) > 0:
  114. confList = np.array(confList)
  115. matchList = np.array(matchList)
  116. sorted_ind = np.argsort(-confList)
  117. confList = confList[sorted_ind]
  118. matchList = matchList[sorted_ind]
  119. for n in range(len(confList)):
  120. match = matchList[n]
  121. if match:
  122. correct += 1
  123. AP += float(correct) / (n + 1)
  124. if numGtCare > 0:
  125. AP /= numGtCare
  126. return AP
  127. perSampleMetrics = {}
  128. matchedSum = 0
  129. Rectangle = namedtuple("Rectangle", "xmin ymin xmax ymax")
  130. gt = rrc_evaluation_funcs.load_folder_file(
  131. gtFilePath, evaluationParams["GT_SAMPLE_NAME_2_ID"]
  132. )
  133. subm = rrc_evaluation_funcs.load_folder_file(
  134. submFilePath, evaluationParams["DET_SAMPLE_NAME_2_ID"], True
  135. )
  136. numGlobalCareGt = 0
  137. numGlobalCareDet = 0
  138. arrGlobalConfidences = []
  139. arrGlobalMatches = []
  140. for resFile in gt:
  141. gtFile = gt[resFile] # rrc_evaluation_funcs.decode_utf8(gt[resFile])
  142. recall = 0
  143. precision = 0
  144. hmean = 0
  145. detMatched = 0
  146. iouMat = np.empty([1, 1])
  147. gtPols = []
  148. detPols = []
  149. gtPolPoints = []
  150. detPolPoints = []
  151. # Array of Ground Truth Polygons' keys marked as don't Care
  152. gtDontCarePolsNum = []
  153. # Array of Detected Polygons' matched with a don't Care GT
  154. detDontCarePolsNum = []
  155. pairs = []
  156. detMatchedNums = []
  157. arrSampleConfidences = []
  158. arrSampleMatch = []
  159. sampleAP = 0
  160. evaluationLog = ""
  161. (
  162. pointsList,
  163. _,
  164. transcriptionsList,
  165. ) = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(
  166. gtFile, evaluationParams["CRLF"], evaluationParams["LTRB"], True, False
  167. )
  168. for n in range(len(pointsList)):
  169. points = pointsList[n]
  170. transcription = transcriptionsList[n]
  171. dontCare = transcription == "###"
  172. if evaluationParams["LTRB"]:
  173. gtRect = Rectangle(*points)
  174. gtPol = rectangle_to_polygon(gtRect)
  175. else:
  176. gtPol = polygon_from_points(points)
  177. gtPols.append(gtPol)
  178. gtPolPoints.append(points)
  179. if dontCare:
  180. gtDontCarePolsNum.append(len(gtPols) - 1)
  181. evaluationLog += (
  182. "GT polygons: "
  183. + str(len(gtPols))
  184. + (
  185. " (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
  186. if len(gtDontCarePolsNum) > 0
  187. else "\n"
  188. )
  189. )
  190. if resFile in subm:
  191. detFile = subm[resFile] # rrc_evaluation_funcs.decode_utf8(subm[resFile])
  192. (
  193. pointsList,
  194. confidencesList,
  195. _,
  196. ) = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(
  197. detFile,
  198. evaluationParams["CRLF"],
  199. evaluationParams["LTRB"],
  200. False,
  201. evaluationParams["CONFIDENCES"],
  202. )
  203. for n in range(len(pointsList)):
  204. points = pointsList[n]
  205. if evaluationParams["LTRB"]:
  206. detRect = Rectangle(*points)
  207. detPol = rectangle_to_polygon(detRect)
  208. else:
  209. detPol = polygon_from_points(points)
  210. detPols.append(detPol)
  211. detPolPoints.append(points)
  212. if len(gtDontCarePolsNum) > 0:
  213. for dontCarePol in gtDontCarePolsNum:
  214. dontCarePol = gtPols[dontCarePol]
  215. intersected_area = get_intersection(dontCarePol, detPol)
  216. pdDimensions = detPol.area()
  217. precision = (
  218. 0 if pdDimensions == 0 else intersected_area / pdDimensions
  219. )
  220. if precision > evaluationParams["AREA_PRECISION_CONSTRAINT"]:
  221. detDontCarePolsNum.append(len(detPols) - 1)
  222. break
  223. evaluationLog += (
  224. "DET polygons: "
  225. + str(len(detPols))
  226. + (
  227. " (" + str(len(detDontCarePolsNum)) + " don't care)\n"
  228. if len(detDontCarePolsNum) > 0
  229. else "\n"
  230. )
  231. )
  232. if len(gtPols) > 0 and len(detPols) > 0:
  233. # Calculate IoU and precision matrixs
  234. outputShape = [len(gtPols), len(detPols)]
  235. iouMat = np.empty(outputShape)
  236. gtRectMat = np.zeros(len(gtPols), np.int8)
  237. detRectMat = np.zeros(len(detPols), np.int8)
  238. for gtNum in range(len(gtPols)):
  239. for detNum in range(len(detPols)):
  240. pG = gtPols[gtNum]
  241. pD = detPols[detNum]
  242. iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
  243. for gtNum in range(len(gtPols)):
  244. for detNum in range(len(detPols)):
  245. if (
  246. gtRectMat[gtNum] == 0
  247. and detRectMat[detNum] == 0
  248. and gtNum not in gtDontCarePolsNum
  249. and detNum not in detDontCarePolsNum
  250. ):
  251. if (
  252. iouMat[gtNum, detNum]
  253. > evaluationParams["IOU_CONSTRAINT"]
  254. ):
  255. gtRectMat[gtNum] = 1
  256. detRectMat[detNum] = 1
  257. detMatched += 1
  258. pairs.append({"gt": gtNum, "det": detNum})
  259. detMatchedNums.append(detNum)
  260. evaluationLog += (
  261. "Match GT #"
  262. + str(gtNum)
  263. + " with Det #"
  264. + str(detNum)
  265. + "\n"
  266. )
  267. if evaluationParams["CONFIDENCES"]:
  268. for detNum in range(len(detPols)):
  269. if detNum not in detDontCarePolsNum:
  270. # we exclude the don't care detections
  271. match = detNum in detMatchedNums
  272. arrSampleConfidences.append(confidencesList[detNum])
  273. arrSampleMatch.append(match)
  274. arrGlobalConfidences.append(confidencesList[detNum])
  275. arrGlobalMatches.append(match)
  276. numGtCare = len(gtPols) - len(gtDontCarePolsNum)
  277. numDetCare = len(detPols) - len(detDontCarePolsNum)
  278. if numGtCare == 0:
  279. recall = float(1)
  280. precision = float(0) if numDetCare > 0 else float(1)
  281. sampleAP = precision
  282. else:
  283. recall = float(detMatched) / numGtCare
  284. precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
  285. if (
  286. evaluationParams["CONFIDENCES"]
  287. and evaluationParams["PER_SAMPLE_RESULTS"]
  288. ):
  289. sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare)
  290. hmean = (
  291. 0
  292. if (precision + recall) == 0
  293. else 2.0 * precision * recall / (precision + recall)
  294. )
  295. matchedSum += detMatched
  296. numGlobalCareGt += numGtCare
  297. numGlobalCareDet += numDetCare
  298. if evaluationParams["PER_SAMPLE_RESULTS"]:
  299. perSampleMetrics[resFile] = {
  300. "precision": precision,
  301. "recall": recall,
  302. "hmean": hmean,
  303. "pairs": pairs,
  304. "AP": sampleAP,
  305. "iouMat": [] if len(detPols) > 100 else iouMat.tolist(),
  306. "gtPolPoints": gtPolPoints,
  307. "detPolPoints": detPolPoints,
  308. "gtDontCare": gtDontCarePolsNum,
  309. "detDontCare": detDontCarePolsNum,
  310. "evaluationParams": evaluationParams,
  311. "evaluationLog": evaluationLog,
  312. }
  313. # Compute MAP and MAR
  314. AP = 0
  315. if evaluationParams["CONFIDENCES"]:
  316. AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)
  317. methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt
  318. methodPrecision = (
  319. 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet
  320. )
  321. methodHmean = (
  322. 0
  323. if methodRecall + methodPrecision == 0
  324. else 2 * methodRecall * methodPrecision / (methodRecall + methodPrecision)
  325. )
  326. methodMetrics = {
  327. "precision": methodPrecision,
  328. "recall": methodRecall,
  329. "hmean": methodHmean,
  330. "AP": AP,
  331. }
  332. resDict = {
  333. "calculated": True,
  334. "Message": "",
  335. "method": methodMetrics,
  336. "per_sample": perSampleMetrics,
  337. }
  338. return resDict
  339. def cal_recall_precision_f1(gt_path, result_path, show_result=False):
  340. p = {"g": gt_path, "s": result_path}
  341. result = rrc_evaluation_funcs.main_evaluation(
  342. p, default_evaluation_params, validate_data, evaluate_method, show_result
  343. )
  344. return result["method"]