# Deteval.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import json
  15. import numpy as np
  16. import scipy.io as io
  17. from ppocr.utils.utility import check_install
  18. from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
  19. def get_socre_A(gt_dir, pred_dict):
  20. allInputs = 1
  21. def input_reading_mod(pred_dict):
  22. """This helper reads input from txt files"""
  23. det = []
  24. n = len(pred_dict)
  25. for i in range(n):
  26. points = pred_dict[i]["points"]
  27. text = pred_dict[i]["texts"]
  28. point = ",".join(
  29. map(
  30. str,
  31. points.reshape(
  32. -1,
  33. ),
  34. )
  35. )
  36. det.append([point, text])
  37. return det
  38. def gt_reading_mod(gt_dict):
  39. """This helper reads groundtruths from mat files"""
  40. gt = []
  41. n = len(gt_dict)
  42. for i in range(n):
  43. points = gt_dict[i]["points"].tolist()
  44. h = len(points)
  45. text = gt_dict[i]["text"]
  46. xx = [
  47. np.array(["x:"], dtype="<U2"),
  48. 0,
  49. np.array(["y:"], dtype="<U2"),
  50. 0,
  51. np.array(["#"], dtype="<U1"),
  52. np.array(["#"], dtype="<U1"),
  53. ]
  54. t_x, t_y = [], []
  55. for j in range(h):
  56. t_x.append(points[j][0])
  57. t_y.append(points[j][1])
  58. xx[1] = np.array([t_x], dtype="int16")
  59. xx[3] = np.array([t_y], dtype="int16")
  60. if text != "":
  61. xx[4] = np.array([text], dtype="U{}".format(len(text)))
  62. xx[5] = np.array(["c"], dtype="<U1")
  63. gt.append(xx)
  64. return gt
  65. def detection_filtering(detections, groundtruths, threshold=0.5):
  66. for gt_id, gt in enumerate(groundtruths):
  67. if (gt[5] == "#") and (gt[1].shape[1] > 1):
  68. gt_x = list(map(int, np.squeeze(gt[1])))
  69. gt_y = list(map(int, np.squeeze(gt[3])))
  70. for det_id, detection in enumerate(detections):
  71. detection_orig = detection
  72. detection = [float(x) for x in detection[0].split(",")]
  73. detection = list(map(int, detection))
  74. det_x = detection[0::2]
  75. det_y = detection[1::2]
  76. det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
  77. if det_gt_iou > threshold:
  78. detections[det_id] = []
  79. detections[:] = [item for item in detections if item != []]
  80. return detections
  81. def sigma_calculation(det_x, det_y, gt_x, gt_y):
  82. """
  83. sigma = inter_area / gt_area
  84. """
  85. return np.round(
  86. (area_of_intersection(det_x, det_y, gt_x, gt_y) / area(gt_x, gt_y)), 2
  87. )
  88. def tau_calculation(det_x, det_y, gt_x, gt_y):
  89. if area(det_x, det_y) == 0.0:
  90. return 0
  91. return np.round(
  92. (area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2
  93. )
  94. ##############################Initialization###################################
  95. # global_sigma = []
  96. # global_tau = []
  97. # global_pred_str = []
  98. # global_gt_str = []
  99. ###############################################################################
  100. for input_id in range(allInputs):
  101. if (
  102. (input_id != ".DS_Store")
  103. and (input_id != "Pascal_result.txt")
  104. and (input_id != "Pascal_result_curved.txt")
  105. and (input_id != "Pascal_result_non_curved.txt")
  106. and (input_id != "Deteval_result.txt")
  107. and (input_id != "Deteval_result_curved.txt")
  108. and (input_id != "Deteval_result_non_curved.txt")
  109. ):
  110. detections = input_reading_mod(pred_dict)
  111. groundtruths = gt_reading_mod(gt_dir)
  112. detections = detection_filtering(
  113. detections, groundtruths
  114. ) # filters detections overlapping with DC area
  115. dc_id = []
  116. for i in range(len(groundtruths)):
  117. if groundtruths[i][5] == "#":
  118. dc_id.append(i)
  119. cnt = 0
  120. for a in dc_id:
  121. num = a - cnt
  122. del groundtruths[num]
  123. cnt += 1
  124. local_sigma_table = np.zeros((len(groundtruths), len(detections)))
  125. local_tau_table = np.zeros((len(groundtruths), len(detections)))
  126. local_pred_str = {}
  127. local_gt_str = {}
  128. for gt_id, gt in enumerate(groundtruths):
  129. if len(detections) > 0:
  130. for det_id, detection in enumerate(detections):
  131. detection_orig = detection
  132. detection = [float(x) for x in detection[0].split(",")]
  133. detection = list(map(int, detection))
  134. pred_seq_str = detection_orig[1].strip()
  135. det_x = detection[0::2]
  136. det_y = detection[1::2]
  137. gt_x = list(map(int, np.squeeze(gt[1])))
  138. gt_y = list(map(int, np.squeeze(gt[3])))
  139. gt_seq_str = str(gt[4].tolist()[0])
  140. local_sigma_table[gt_id, det_id] = sigma_calculation(
  141. det_x, det_y, gt_x, gt_y
  142. )
  143. local_tau_table[gt_id, det_id] = tau_calculation(
  144. det_x, det_y, gt_x, gt_y
  145. )
  146. local_pred_str[det_id] = pred_seq_str
  147. local_gt_str[gt_id] = gt_seq_str
  148. global_sigma = local_sigma_table
  149. global_tau = local_tau_table
  150. global_pred_str = local_pred_str
  151. global_gt_str = local_gt_str
  152. single_data = {}
  153. single_data["sigma"] = global_sigma
  154. single_data["global_tau"] = global_tau
  155. single_data["global_pred_str"] = global_pred_str
  156. single_data["global_gt_str"] = global_gt_str
  157. return single_data
  158. def get_socre_B(gt_dir, img_id, pred_dict):
  159. allInputs = 1
  160. def input_reading_mod(pred_dict):
  161. """This helper reads input from txt files"""
  162. det = []
  163. n = len(pred_dict)
  164. for i in range(n):
  165. points = pred_dict[i]["points"]
  166. text = pred_dict[i]["texts"]
  167. point = ",".join(
  168. map(
  169. str,
  170. points.reshape(
  171. -1,
  172. ),
  173. )
  174. )
  175. det.append([point, text])
  176. return det
  177. def gt_reading_mod(gt_dir, gt_id):
  178. gt = io.loadmat("%s/poly_gt_img%s.mat" % (gt_dir, gt_id))
  179. gt = gt["polygt"]
  180. return gt
  181. def detection_filtering(detections, groundtruths, threshold=0.5):
  182. for gt_id, gt in enumerate(groundtruths):
  183. if (gt[5] == "#") and (gt[1].shape[1] > 1):
  184. gt_x = list(map(int, np.squeeze(gt[1])))
  185. gt_y = list(map(int, np.squeeze(gt[3])))
  186. for det_id, detection in enumerate(detections):
  187. detection_orig = detection
  188. detection = [float(x) for x in detection[0].split(",")]
  189. detection = list(map(int, detection))
  190. det_x = detection[0::2]
  191. det_y = detection[1::2]
  192. det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
  193. if det_gt_iou > threshold:
  194. detections[det_id] = []
  195. detections[:] = [item for item in detections if item != []]
  196. return detections
  197. def sigma_calculation(det_x, det_y, gt_x, gt_y):
  198. """
  199. sigma = inter_area / gt_area
  200. """
  201. return np.round(
  202. (area_of_intersection(det_x, det_y, gt_x, gt_y) / area(gt_x, gt_y)), 2
  203. )
  204. def tau_calculation(det_x, det_y, gt_x, gt_y):
  205. if area(det_x, det_y) == 0.0:
  206. return 0
  207. return np.round(
  208. (area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2
  209. )
  210. ##############################Initialization###################################
  211. # global_sigma = []
  212. # global_tau = []
  213. # global_pred_str = []
  214. # global_gt_str = []
  215. ###############################################################################
  216. for input_id in range(allInputs):
  217. if (
  218. (input_id != ".DS_Store")
  219. and (input_id != "Pascal_result.txt")
  220. and (input_id != "Pascal_result_curved.txt")
  221. and (input_id != "Pascal_result_non_curved.txt")
  222. and (input_id != "Deteval_result.txt")
  223. and (input_id != "Deteval_result_curved.txt")
  224. and (input_id != "Deteval_result_non_curved.txt")
  225. ):
  226. detections = input_reading_mod(pred_dict)
  227. groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
  228. detections = detection_filtering(
  229. detections, groundtruths
  230. ) # filters detections overlapping with DC area
  231. dc_id = []
  232. for i in range(len(groundtruths)):
  233. if groundtruths[i][5] == "#":
  234. dc_id.append(i)
  235. cnt = 0
  236. for a in dc_id:
  237. num = a - cnt
  238. del groundtruths[num]
  239. cnt += 1
  240. local_sigma_table = np.zeros((len(groundtruths), len(detections)))
  241. local_tau_table = np.zeros((len(groundtruths), len(detections)))
  242. local_pred_str = {}
  243. local_gt_str = {}
  244. for gt_id, gt in enumerate(groundtruths):
  245. if len(detections) > 0:
  246. for det_id, detection in enumerate(detections):
  247. detection_orig = detection
  248. detection = [float(x) for x in detection[0].split(",")]
  249. detection = list(map(int, detection))
  250. pred_seq_str = detection_orig[1].strip()
  251. det_x = detection[0::2]
  252. det_y = detection[1::2]
  253. gt_x = list(map(int, np.squeeze(gt[1])))
  254. gt_y = list(map(int, np.squeeze(gt[3])))
  255. gt_seq_str = str(gt[4].tolist()[0])
  256. local_sigma_table[gt_id, det_id] = sigma_calculation(
  257. det_x, det_y, gt_x, gt_y
  258. )
  259. local_tau_table[gt_id, det_id] = tau_calculation(
  260. det_x, det_y, gt_x, gt_y
  261. )
  262. local_pred_str[det_id] = pred_seq_str
  263. local_gt_str[gt_id] = gt_seq_str
  264. global_sigma = local_sigma_table
  265. global_tau = local_tau_table
  266. global_pred_str = local_pred_str
  267. global_gt_str = local_gt_str
  268. single_data = {}
  269. single_data["sigma"] = global_sigma
  270. single_data["global_tau"] = global_tau
  271. single_data["global_pred_str"] = global_pred_str
  272. single_data["global_gt_str"] = global_gt_str
  273. return single_data
  274. def get_score_C(gt_label, text, pred_bboxes):
  275. """
  276. get score for CentripetalText (CT) prediction.
  277. """
  278. check_install("Polygon", "Polygon3")
  279. import Polygon as plg
  280. def gt_reading_mod(gt_label, text):
  281. """This helper reads groundtruths from mat files"""
  282. groundtruths = []
  283. nbox = len(gt_label)
  284. for i in range(nbox):
  285. label = {"transcription": text[i][0], "points": gt_label[i].numpy()}
  286. groundtruths.append(label)
  287. return groundtruths
  288. def get_union(pD, pG):
  289. areaA = pD.area()
  290. areaB = pG.area()
  291. return areaA + areaB - get_intersection(pD, pG)
  292. def get_intersection(pD, pG):
  293. pInt = pD & pG
  294. if len(pInt) == 0:
  295. return 0
  296. return pInt.area()
  297. def detection_filtering(detections, groundtruths, threshold=0.5):
  298. for gt in groundtruths:
  299. point_num = gt["points"].shape[1] // 2
  300. if gt["transcription"] == "###" and (point_num > 1):
  301. gt_p = np.array(gt["points"]).reshape(point_num, 2).astype("int32")
  302. gt_p = plg.Polygon(gt_p)
  303. for det_id, detection in enumerate(detections):
  304. det_y = detection[0::2]
  305. det_x = detection[1::2]
  306. det_p = np.concatenate((np.array(det_x), np.array(det_y)))
  307. det_p = det_p.reshape(2, -1).transpose()
  308. det_p = plg.Polygon(det_p)
  309. try:
  310. det_gt_iou = get_intersection(det_p, gt_p) / det_p.area()
  311. except:
  312. print(det_x, det_y, gt_p)
  313. if det_gt_iou > threshold:
  314. detections[det_id] = []
  315. detections[:] = [item for item in detections if item != []]
  316. return detections
  317. def sigma_calculation(det_p, gt_p):
  318. """
  319. sigma = inter_area / gt_area
  320. """
  321. if gt_p.area() == 0.0:
  322. return 0
  323. return get_intersection(det_p, gt_p) / gt_p.area()
  324. def tau_calculation(det_p, gt_p):
  325. """
  326. tau = inter_area / det_area
  327. """
  328. if det_p.area() == 0.0:
  329. return 0
  330. return get_intersection(det_p, gt_p) / det_p.area()
  331. detections = []
  332. for item in pred_bboxes:
  333. detections.append(item[:, ::-1].reshape(-1))
  334. groundtruths = gt_reading_mod(gt_label, text)
  335. detections = detection_filtering(
  336. detections, groundtruths
  337. ) # filters detections overlapping with DC area
  338. for idx in range(len(groundtruths) - 1, -1, -1):
  339. # NOTE: source code use 'orin' to indicate '#', here we use 'anno',
  340. # which may cause slight drop in fscore, about 0.12
  341. if groundtruths[idx]["transcription"] == "###":
  342. groundtruths.pop(idx)
  343. local_sigma_table = np.zeros((len(groundtruths), len(detections)))
  344. local_tau_table = np.zeros((len(groundtruths), len(detections)))
  345. for gt_id, gt in enumerate(groundtruths):
  346. if len(detections) > 0:
  347. for det_id, detection in enumerate(detections):
  348. point_num = gt["points"].shape[1] // 2
  349. gt_p = np.array(gt["points"]).reshape(point_num, 2).astype("int32")
  350. gt_p = plg.Polygon(gt_p)
  351. det_y = detection[0::2]
  352. det_x = detection[1::2]
  353. det_p = np.concatenate((np.array(det_x), np.array(det_y)))
  354. det_p = det_p.reshape(2, -1).transpose()
  355. det_p = plg.Polygon(det_p)
  356. local_sigma_table[gt_id, det_id] = sigma_calculation(det_p, gt_p)
  357. local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p)
  358. data = {}
  359. data["sigma"] = local_sigma_table
  360. data["global_tau"] = local_tau_table
  361. data["global_pred_str"] = ""
  362. data["global_gt_str"] = ""
  363. return data
  364. def combine_results(all_data, rec_flag=True):
  365. tr = 0.7
  366. tp = 0.6
  367. fsc_k = 0.8
  368. k = 2
  369. global_sigma = []
  370. global_tau = []
  371. global_pred_str = []
  372. global_gt_str = []
  373. for data in all_data:
  374. global_sigma.append(data["sigma"])
  375. global_tau.append(data["global_tau"])
  376. global_pred_str.append(data["global_pred_str"])
  377. global_gt_str.append(data["global_gt_str"])
  378. global_accumulative_recall = 0
  379. global_accumulative_precision = 0
  380. total_num_gt = 0
  381. total_num_det = 0
  382. hit_str_count = 0
  383. hit_count = 0
  384. def one_to_one(
  385. local_sigma_table,
  386. local_tau_table,
  387. local_accumulative_recall,
  388. local_accumulative_precision,
  389. global_accumulative_recall,
  390. global_accumulative_precision,
  391. gt_flag,
  392. det_flag,
  393. idy,
  394. rec_flag,
  395. ):
  396. hit_str_num = 0
  397. for gt_id in range(num_gt):
  398. gt_matching_qualified_sigma_candidates = np.where(
  399. local_sigma_table[gt_id, :] > tr
  400. )
  401. gt_matching_num_qualified_sigma_candidates = (
  402. gt_matching_qualified_sigma_candidates[0].shape[0]
  403. )
  404. gt_matching_qualified_tau_candidates = np.where(
  405. local_tau_table[gt_id, :] > tp
  406. )
  407. gt_matching_num_qualified_tau_candidates = (
  408. gt_matching_qualified_tau_candidates[0].shape[0]
  409. )
  410. det_matching_qualified_sigma_candidates = np.where(
  411. local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]] > tr
  412. )
  413. det_matching_num_qualified_sigma_candidates = (
  414. det_matching_qualified_sigma_candidates[0].shape[0]
  415. )
  416. det_matching_qualified_tau_candidates = np.where(
  417. local_tau_table[:, gt_matching_qualified_tau_candidates[0]] > tp
  418. )
  419. det_matching_num_qualified_tau_candidates = (
  420. det_matching_qualified_tau_candidates[0].shape[0]
  421. )
  422. if (
  423. (gt_matching_num_qualified_sigma_candidates == 1)
  424. and (gt_matching_num_qualified_tau_candidates == 1)
  425. and (det_matching_num_qualified_sigma_candidates == 1)
  426. and (det_matching_num_qualified_tau_candidates == 1)
  427. ):
  428. global_accumulative_recall = global_accumulative_recall + 1.0
  429. global_accumulative_precision = global_accumulative_precision + 1.0
  430. local_accumulative_recall = local_accumulative_recall + 1.0
  431. local_accumulative_precision = local_accumulative_precision + 1.0
  432. gt_flag[0, gt_id] = 1
  433. matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
  434. # recg start
  435. if rec_flag:
  436. gt_str_cur = global_gt_str[idy][gt_id]
  437. pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[0]]
  438. if pred_str_cur == gt_str_cur:
  439. hit_str_num += 1
  440. else:
  441. if pred_str_cur.lower() == gt_str_cur.lower():
  442. hit_str_num += 1
  443. # recg end
  444. det_flag[0, matched_det_id] = 1
  445. return (
  446. local_accumulative_recall,
  447. local_accumulative_precision,
  448. global_accumulative_recall,
  449. global_accumulative_precision,
  450. gt_flag,
  451. det_flag,
  452. hit_str_num,
  453. )
  454. def one_to_many(
  455. local_sigma_table,
  456. local_tau_table,
  457. local_accumulative_recall,
  458. local_accumulative_precision,
  459. global_accumulative_recall,
  460. global_accumulative_precision,
  461. gt_flag,
  462. det_flag,
  463. idy,
  464. rec_flag,
  465. ):
  466. hit_str_num = 0
  467. for gt_id in range(num_gt):
  468. # skip the following if the groundtruth was matched
  469. if gt_flag[0, gt_id] > 0:
  470. continue
  471. non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
  472. num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
  473. if num_non_zero_in_sigma >= k:
  474. ####search for all detections that overlaps with this groundtruth
  475. qualified_tau_candidates = np.where(
  476. (local_tau_table[gt_id, :] >= tp) & (det_flag[0, :] == 0)
  477. )
  478. num_qualified_tau_candidates = qualified_tau_candidates[0].shape[0]
  479. if num_qualified_tau_candidates == 1:
  480. if (local_tau_table[gt_id, qualified_tau_candidates] >= tp) and (
  481. local_sigma_table[gt_id, qualified_tau_candidates] >= tr
  482. ):
  483. # became an one-to-one case
  484. global_accumulative_recall = global_accumulative_recall + 1.0
  485. global_accumulative_precision = (
  486. global_accumulative_precision + 1.0
  487. )
  488. local_accumulative_recall = local_accumulative_recall + 1.0
  489. local_accumulative_precision = (
  490. local_accumulative_precision + 1.0
  491. )
  492. gt_flag[0, gt_id] = 1
  493. det_flag[0, qualified_tau_candidates] = 1
  494. # recg start
  495. if rec_flag:
  496. gt_str_cur = global_gt_str[idy][gt_id]
  497. pred_str_cur = global_pred_str[idy][
  498. qualified_tau_candidates[0].tolist()[0]
  499. ]
  500. if pred_str_cur == gt_str_cur:
  501. hit_str_num += 1
  502. else:
  503. if pred_str_cur.lower() == gt_str_cur.lower():
  504. hit_str_num += 1
  505. # recg end
  506. elif np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) >= tr:
  507. gt_flag[0, gt_id] = 1
  508. det_flag[0, qualified_tau_candidates] = 1
  509. # recg start
  510. if rec_flag:
  511. gt_str_cur = global_gt_str[idy][gt_id]
  512. pred_str_cur = global_pred_str[idy][
  513. qualified_tau_candidates[0].tolist()[0]
  514. ]
  515. if pred_str_cur == gt_str_cur:
  516. hit_str_num += 1
  517. else:
  518. if pred_str_cur.lower() == gt_str_cur.lower():
  519. hit_str_num += 1
  520. # recg end
  521. global_accumulative_recall = global_accumulative_recall + fsc_k
  522. global_accumulative_precision = (
  523. global_accumulative_precision
  524. + num_qualified_tau_candidates * fsc_k
  525. )
  526. local_accumulative_recall = local_accumulative_recall + fsc_k
  527. local_accumulative_precision = (
  528. local_accumulative_precision
  529. + num_qualified_tau_candidates * fsc_k
  530. )
  531. return (
  532. local_accumulative_recall,
  533. local_accumulative_precision,
  534. global_accumulative_recall,
  535. global_accumulative_precision,
  536. gt_flag,
  537. det_flag,
  538. hit_str_num,
  539. )
  540. def many_to_one(
  541. local_sigma_table,
  542. local_tau_table,
  543. local_accumulative_recall,
  544. local_accumulative_precision,
  545. global_accumulative_recall,
  546. global_accumulative_precision,
  547. gt_flag,
  548. det_flag,
  549. idy,
  550. rec_flag,
  551. ):
  552. hit_str_num = 0
  553. for det_id in range(num_det):
  554. # skip the following if the detection was matched
  555. if det_flag[0, det_id] > 0:
  556. continue
  557. non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
  558. num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
  559. if num_non_zero_in_tau >= k:
  560. ####search for all detections that overlaps with this groundtruth
  561. qualified_sigma_candidates = np.where(
  562. (local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0)
  563. )
  564. num_qualified_sigma_candidates = qualified_sigma_candidates[0].shape[0]
  565. if num_qualified_sigma_candidates == 1:
  566. if (local_tau_table[qualified_sigma_candidates, det_id] >= tp) and (
  567. local_sigma_table[qualified_sigma_candidates, det_id] >= tr
  568. ):
  569. # became an one-to-one case
  570. global_accumulative_recall = global_accumulative_recall + 1.0
  571. global_accumulative_precision = (
  572. global_accumulative_precision + 1.0
  573. )
  574. local_accumulative_recall = local_accumulative_recall + 1.0
  575. local_accumulative_precision = (
  576. local_accumulative_precision + 1.0
  577. )
  578. gt_flag[0, qualified_sigma_candidates] = 1
  579. det_flag[0, det_id] = 1
  580. # recg start
  581. if rec_flag:
  582. pred_str_cur = global_pred_str[idy][det_id]
  583. gt_len = len(qualified_sigma_candidates[0])
  584. for idx in range(gt_len):
  585. ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
  586. if ele_gt_id not in global_gt_str[idy]:
  587. continue
  588. gt_str_cur = global_gt_str[idy][ele_gt_id]
  589. if pred_str_cur == gt_str_cur:
  590. hit_str_num += 1
  591. break
  592. else:
  593. if pred_str_cur.lower() == gt_str_cur.lower():
  594. hit_str_num += 1
  595. break
  596. # recg end
  597. elif np.sum(local_tau_table[qualified_sigma_candidates, det_id]) >= tp:
  598. det_flag[0, det_id] = 1
  599. gt_flag[0, qualified_sigma_candidates] = 1
  600. # recg start
  601. if rec_flag:
  602. pred_str_cur = global_pred_str[idy][det_id]
  603. gt_len = len(qualified_sigma_candidates[0])
  604. for idx in range(gt_len):
  605. ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
  606. if ele_gt_id not in global_gt_str[idy]:
  607. continue
  608. gt_str_cur = global_gt_str[idy][ele_gt_id]
  609. if pred_str_cur == gt_str_cur:
  610. hit_str_num += 1
  611. break
  612. else:
  613. if pred_str_cur.lower() == gt_str_cur.lower():
  614. hit_str_num += 1
  615. break
  616. # recg end
  617. global_accumulative_recall = (
  618. global_accumulative_recall
  619. + num_qualified_sigma_candidates * fsc_k
  620. )
  621. global_accumulative_precision = (
  622. global_accumulative_precision + fsc_k
  623. )
  624. local_accumulative_recall = (
  625. local_accumulative_recall
  626. + num_qualified_sigma_candidates * fsc_k
  627. )
  628. local_accumulative_precision = local_accumulative_precision + fsc_k
  629. return (
  630. local_accumulative_recall,
  631. local_accumulative_precision,
  632. global_accumulative_recall,
  633. global_accumulative_precision,
  634. gt_flag,
  635. det_flag,
  636. hit_str_num,
  637. )
  638. for idx in range(len(global_sigma)):
  639. local_sigma_table = np.array(global_sigma[idx])
  640. local_tau_table = global_tau[idx]
  641. num_gt = local_sigma_table.shape[0]
  642. num_det = local_sigma_table.shape[1]
  643. total_num_gt = total_num_gt + num_gt
  644. total_num_det = total_num_det + num_det
  645. local_accumulative_recall = 0
  646. local_accumulative_precision = 0
  647. gt_flag = np.zeros((1, num_gt))
  648. det_flag = np.zeros((1, num_det))
  649. #######first check for one-to-one case##########
  650. (
  651. local_accumulative_recall,
  652. local_accumulative_precision,
  653. global_accumulative_recall,
  654. global_accumulative_precision,
  655. gt_flag,
  656. det_flag,
  657. hit_str_num,
  658. ) = one_to_one(
  659. local_sigma_table,
  660. local_tau_table,
  661. local_accumulative_recall,
  662. local_accumulative_precision,
  663. global_accumulative_recall,
  664. global_accumulative_precision,
  665. gt_flag,
  666. det_flag,
  667. idx,
  668. rec_flag,
  669. )
  670. hit_str_count += hit_str_num
  671. #######then check for one-to-many case##########
  672. (
  673. local_accumulative_recall,
  674. local_accumulative_precision,
  675. global_accumulative_recall,
  676. global_accumulative_precision,
  677. gt_flag,
  678. det_flag,
  679. hit_str_num,
  680. ) = one_to_many(
  681. local_sigma_table,
  682. local_tau_table,
  683. local_accumulative_recall,
  684. local_accumulative_precision,
  685. global_accumulative_recall,
  686. global_accumulative_precision,
  687. gt_flag,
  688. det_flag,
  689. idx,
  690. rec_flag,
  691. )
  692. hit_str_count += hit_str_num
  693. #######then check for many-to-one case##########
  694. (
  695. local_accumulative_recall,
  696. local_accumulative_precision,
  697. global_accumulative_recall,
  698. global_accumulative_precision,
  699. gt_flag,
  700. det_flag,
  701. hit_str_num,
  702. ) = many_to_one(
  703. local_sigma_table,
  704. local_tau_table,
  705. local_accumulative_recall,
  706. local_accumulative_precision,
  707. global_accumulative_recall,
  708. global_accumulative_precision,
  709. gt_flag,
  710. det_flag,
  711. idx,
  712. rec_flag,
  713. )
  714. hit_str_count += hit_str_num
  715. try:
  716. recall = global_accumulative_recall / total_num_gt
  717. except ZeroDivisionError:
  718. recall = 0
  719. try:
  720. precision = global_accumulative_precision / total_num_det
  721. except ZeroDivisionError:
  722. precision = 0
  723. try:
  724. f_score = 2 * precision * recall / (precision + recall)
  725. except ZeroDivisionError:
  726. f_score = 0
  727. try:
  728. seqerr = 1 - float(hit_str_count) / global_accumulative_recall
  729. except ZeroDivisionError:
  730. seqerr = 1
  731. try:
  732. recall_e2e = float(hit_str_count) / total_num_gt
  733. except ZeroDivisionError:
  734. recall_e2e = 0
  735. try:
  736. precision_e2e = float(hit_str_count) / total_num_det
  737. except ZeroDivisionError:
  738. precision_e2e = 0
  739. try:
  740. f_score_e2e = 2 * precision_e2e * recall_e2e / (precision_e2e + recall_e2e)
  741. except ZeroDivisionError:
  742. f_score_e2e = 0
  743. final = {
  744. "total_num_gt": total_num_gt,
  745. "total_num_det": total_num_det,
  746. "global_accumulative_recall": global_accumulative_recall,
  747. "hit_str_count": hit_str_count,
  748. "recall": recall,
  749. "precision": precision,
  750. "f_score": f_score,
  751. "seqerr": seqerr,
  752. "recall_e2e": recall_e2e,
  753. "precision_e2e": precision_e2e,
  754. "f_score_e2e": f_score_e2e,
  755. }
  756. return final