label_ops.py 73 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from __future__ import unicode_literals
  18. import os
  19. from enum import Enum
  20. import copy
  21. import numpy as np
  22. import string
  23. from shapely.geometry import LineString, Point, Polygon
  24. import json
  25. import copy
  26. import random
  27. from random import sample
  28. from collections import defaultdict
  29. from ppocr.utils.logging import get_logger
  30. from ppocr.data.imaug.vqa.augment import order_by_tbyx
  31. class ClsLabelEncode(object):
  32. def __init__(self, label_list, **kwargs):
  33. self.label_list = label_list
  34. def __call__(self, data):
  35. label = data["label"]
  36. if label not in self.label_list:
  37. return None
  38. label = self.label_list.index(label)
  39. data["label"] = label
  40. return data
  41. class DetLabelEncode(object):
  42. def __init__(self, **kwargs):
  43. pass
  44. def __call__(self, data):
  45. label = data["label"]
  46. label = json.loads(label)
  47. nBox = len(label)
  48. boxes, txts, txt_tags = [], [], []
  49. for bno in range(0, nBox):
  50. box = label[bno]["points"]
  51. txt = label[bno]["transcription"]
  52. boxes.append(box)
  53. txts.append(txt)
  54. if txt in ["*", "###"]:
  55. txt_tags.append(True)
  56. else:
  57. txt_tags.append(False)
  58. if len(boxes) == 0:
  59. return None
  60. boxes = self.expand_points_num(boxes)
  61. boxes = np.array(boxes, dtype=np.float32)
  62. txt_tags = np.array(txt_tags, dtype=np.bool_)
  63. data["polys"] = boxes
  64. data["texts"] = txts
  65. data["ignore_tags"] = txt_tags
  66. return data
  67. def order_points_clockwise(self, pts):
  68. rect = np.zeros((4, 2), dtype="float32")
  69. s = pts.sum(axis=1)
  70. rect[0] = pts[np.argmin(s)]
  71. rect[2] = pts[np.argmax(s)]
  72. tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
  73. diff = np.diff(np.array(tmp), axis=1)
  74. rect[1] = tmp[np.argmin(diff)]
  75. rect[3] = tmp[np.argmax(diff)]
  76. return rect
  77. def expand_points_num(self, boxes):
  78. max_points_num = 0
  79. for box in boxes:
  80. if len(box) > max_points_num:
  81. max_points_num = len(box)
  82. ex_boxes = []
  83. for box in boxes:
  84. ex_box = box + [box[-1]] * (max_points_num - len(box))
  85. ex_boxes.append(ex_box)
  86. return ex_boxes
  87. class BaseRecLabelEncode(object):
  88. """Convert between text-label and text-index"""
  89. def __init__(
  90. self,
  91. max_text_length,
  92. character_dict_path=None,
  93. use_space_char=False,
  94. lower=False,
  95. ):
  96. self.max_text_len = max_text_length
  97. self.beg_str = "sos"
  98. self.end_str = "eos"
  99. self.lower = lower
  100. if character_dict_path is None:
  101. logger = get_logger()
  102. logger.warning(
  103. "The character_dict_path is None, model can only recognize number and lower letters"
  104. )
  105. self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
  106. dict_character = list(self.character_str)
  107. self.lower = True
  108. else:
  109. self.character_str = []
  110. with open(character_dict_path, "rb") as fin:
  111. lines = fin.readlines()
  112. for line in lines:
  113. line = line.decode("utf-8").strip("\n").strip("\r\n")
  114. self.character_str.append(line)
  115. if use_space_char:
  116. self.character_str.append(" ")
  117. dict_character = list(self.character_str)
  118. dict_character = self.add_special_char(dict_character)
  119. self.dict = {}
  120. for i, char in enumerate(dict_character):
  121. self.dict[char] = i
  122. self.character = dict_character
  123. def add_special_char(self, dict_character):
  124. return dict_character
  125. def encode(self, text):
  126. """convert text-label into text-index.
  127. input:
  128. text: text labels of each image. [batch_size]
  129. output:
  130. text: concatenated text index for CTCLoss.
  131. [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
  132. length: length of each text. [batch_size]
  133. """
  134. if len(text) == 0 or len(text) > self.max_text_len:
  135. return None
  136. if self.lower:
  137. text = text.lower()
  138. text_list = []
  139. for char in text:
  140. if char not in self.dict:
  141. # logger = get_logger()
  142. # logger.warning('{} is not in dict'.format(char))
  143. continue
  144. text_list.append(self.dict[char])
  145. if len(text_list) == 0:
  146. return None
  147. return text_list
  148. class CTCLabelEncode(BaseRecLabelEncode):
  149. """Convert between text-label and text-index"""
  150. def __init__(
  151. self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
  152. ):
  153. super(CTCLabelEncode, self).__init__(
  154. max_text_length, character_dict_path, use_space_char
  155. )
  156. def __call__(self, data):
  157. text = data["label"]
  158. text = self.encode(text)
  159. if text is None:
  160. return None
  161. data["length"] = np.array(len(text))
  162. text = text + [0] * (self.max_text_len - len(text))
  163. data["label"] = np.array(text)
  164. label = [0] * len(self.character)
  165. for x in text:
  166. label[x] += 1
  167. data["label_ace"] = np.array(label)
  168. return data
  169. def add_special_char(self, dict_character):
  170. dict_character = ["blank"] + dict_character
  171. return dict_character
  172. class E2ELabelEncodeTest(BaseRecLabelEncode):
  173. def __init__(
  174. self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
  175. ):
  176. super(E2ELabelEncodeTest, self).__init__(
  177. max_text_length, character_dict_path, use_space_char
  178. )
  179. def __call__(self, data):
  180. import json
  181. padnum = len(self.dict)
  182. label = data["label"]
  183. label = json.loads(label)
  184. nBox = len(label)
  185. boxes, txts, txt_tags = [], [], []
  186. for bno in range(0, nBox):
  187. box = label[bno]["points"]
  188. txt = label[bno]["transcription"]
  189. boxes.append(box)
  190. txts.append(txt)
  191. if txt in ["*", "###"]:
  192. txt_tags.append(True)
  193. else:
  194. txt_tags.append(False)
  195. boxes = np.array(boxes, dtype=np.float32)
  196. txt_tags = np.array(txt_tags, dtype=np.bool_)
  197. data["polys"] = boxes
  198. data["ignore_tags"] = txt_tags
  199. temp_texts = []
  200. for text in txts:
  201. text = text.lower()
  202. text = self.encode(text)
  203. if text is None:
  204. return None
  205. text = text + [padnum] * (self.max_text_len - len(text)) # use 36 to pad
  206. temp_texts.append(text)
  207. data["texts"] = np.array(temp_texts)
  208. return data
  209. class E2ELabelEncodeTrain(object):
  210. def __init__(self, **kwargs):
  211. pass
  212. def __call__(self, data):
  213. import json
  214. label = data["label"]
  215. label = json.loads(label)
  216. nBox = len(label)
  217. boxes, txts, txt_tags = [], [], []
  218. for bno in range(0, nBox):
  219. box = label[bno]["points"]
  220. txt = label[bno]["transcription"]
  221. boxes.append(box)
  222. txts.append(txt)
  223. if txt in ["*", "###"]:
  224. txt_tags.append(True)
  225. else:
  226. txt_tags.append(False)
  227. boxes = np.array(boxes, dtype=np.float32)
  228. txt_tags = np.array(txt_tags, dtype=np.bool_)
  229. data["polys"] = boxes
  230. data["texts"] = txts
  231. data["ignore_tags"] = txt_tags
  232. return data
  233. class KieLabelEncode(object):
  234. def __init__(
  235. self, character_dict_path, class_path, norm=10, directed=False, **kwargs
  236. ):
  237. super(KieLabelEncode, self).__init__()
  238. self.dict = dict({"": 0})
  239. self.label2classid_map = dict()
  240. with open(character_dict_path, "r", encoding="utf-8") as fr:
  241. idx = 1
  242. for line in fr:
  243. char = line.strip()
  244. self.dict[char] = idx
  245. idx += 1
  246. with open(class_path, "r") as fin:
  247. lines = fin.readlines()
  248. for idx, line in enumerate(lines):
  249. line = line.strip("\n")
  250. self.label2classid_map[line] = idx
  251. self.norm = norm
  252. self.directed = directed
  253. def compute_relation(self, boxes):
  254. """Compute relation between every two boxes."""
  255. x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
  256. x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
  257. ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
  258. dxs = (x1s[:, 0][None] - x1s) / self.norm
  259. dys = (y1s[:, 0][None] - y1s) / self.norm
  260. xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
  261. whs = ws / hs + np.zeros_like(xhhs)
  262. relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
  263. bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
  264. return relations, bboxes
  265. def pad_text_indices(self, text_inds):
  266. """Pad text index to same length."""
  267. max_len = 300
  268. recoder_len = max([len(text_ind) for text_ind in text_inds])
  269. padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
  270. for idx, text_ind in enumerate(text_inds):
  271. padded_text_inds[idx, : len(text_ind)] = np.array(text_ind)
  272. return padded_text_inds, recoder_len
  273. def list_to_numpy(self, ann_infos):
  274. """Convert bboxes, relations, texts and labels to ndarray."""
  275. boxes, text_inds = ann_infos["points"], ann_infos["text_inds"]
  276. boxes = np.array(boxes, np.int32)
  277. relations, bboxes = self.compute_relation(boxes)
  278. labels = ann_infos.get("labels", None)
  279. if labels is not None:
  280. labels = np.array(labels, np.int32)
  281. edges = ann_infos.get("edges", None)
  282. if edges is not None:
  283. labels = labels[:, None]
  284. edges = np.array(edges)
  285. edges = (edges[:, None] == edges[None, :]).astype(np.int32)
  286. if self.directed:
  287. edges = (edges & labels == 1).astype(np.int32)
  288. np.fill_diagonal(edges, -1)
  289. labels = np.concatenate([labels, edges], -1)
  290. padded_text_inds, recoder_len = self.pad_text_indices(text_inds)
  291. max_num = 300
  292. temp_bboxes = np.zeros([max_num, 4])
  293. h, _ = bboxes.shape
  294. temp_bboxes[:h, :] = bboxes
  295. temp_relations = np.zeros([max_num, max_num, 5])
  296. temp_relations[:h, :h, :] = relations
  297. temp_padded_text_inds = np.zeros([max_num, max_num])
  298. temp_padded_text_inds[:h, :] = padded_text_inds
  299. temp_labels = np.zeros([max_num, max_num])
  300. temp_labels[:h, : h + 1] = labels
  301. tag = np.array([h, recoder_len])
  302. return dict(
  303. image=ann_infos["image"],
  304. points=temp_bboxes,
  305. relations=temp_relations,
  306. texts=temp_padded_text_inds,
  307. labels=temp_labels,
  308. tag=tag,
  309. )
  310. def convert_canonical(self, points_x, points_y):
  311. assert len(points_x) == 4
  312. assert len(points_y) == 4
  313. points = [Point(points_x[i], points_y[i]) for i in range(4)]
  314. polygon = Polygon([(p.x, p.y) for p in points])
  315. min_x, min_y, _, _ = polygon.bounds
  316. points_to_lefttop = [
  317. LineString([points[i], Point(min_x, min_y)]) for i in range(4)
  318. ]
  319. distances = np.array([line.length for line in points_to_lefttop])
  320. sort_dist_idx = np.argsort(distances)
  321. lefttop_idx = sort_dist_idx[0]
  322. if lefttop_idx == 0:
  323. point_orders = [0, 1, 2, 3]
  324. elif lefttop_idx == 1:
  325. point_orders = [1, 2, 3, 0]
  326. elif lefttop_idx == 2:
  327. point_orders = [2, 3, 0, 1]
  328. else:
  329. point_orders = [3, 0, 1, 2]
  330. sorted_points_x = [points_x[i] for i in point_orders]
  331. sorted_points_y = [points_y[j] for j in point_orders]
  332. return sorted_points_x, sorted_points_y
  333. def sort_vertex(self, points_x, points_y):
  334. assert len(points_x) == 4
  335. assert len(points_y) == 4
  336. x = np.array(points_x)
  337. y = np.array(points_y)
  338. center_x = np.sum(x) * 0.25
  339. center_y = np.sum(y) * 0.25
  340. x_arr = np.array(x - center_x)
  341. y_arr = np.array(y - center_y)
  342. angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
  343. sort_idx = np.argsort(angle)
  344. sorted_points_x, sorted_points_y = [], []
  345. for i in range(4):
  346. sorted_points_x.append(points_x[sort_idx[i]])
  347. sorted_points_y.append(points_y[sort_idx[i]])
  348. return self.convert_canonical(sorted_points_x, sorted_points_y)
  349. def __call__(self, data):
  350. import json
  351. label = data["label"]
  352. annotations = json.loads(label)
  353. boxes, texts, text_inds, labels, edges = [], [], [], [], []
  354. for ann in annotations:
  355. box = ann["points"]
  356. x_list = [box[i][0] for i in range(4)]
  357. y_list = [box[i][1] for i in range(4)]
  358. sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list)
  359. sorted_box = []
  360. for x, y in zip(sorted_x_list, sorted_y_list):
  361. sorted_box.append(x)
  362. sorted_box.append(y)
  363. boxes.append(sorted_box)
  364. text = ann["transcription"]
  365. texts.append(ann["transcription"])
  366. text_ind = [self.dict[c] for c in text if c in self.dict]
  367. text_inds.append(text_ind)
  368. if "label" in ann.keys():
  369. labels.append(self.label2classid_map[ann["label"]])
  370. elif "key_cls" in ann.keys():
  371. labels.append(ann["key_cls"])
  372. else:
  373. raise ValueError(
  374. "Cannot found 'key_cls' in ann.keys(), please check your training annotation."
  375. )
  376. edges.append(ann.get("edge", 0))
  377. ann_infos = dict(
  378. image=data["image"],
  379. points=boxes,
  380. texts=texts,
  381. text_inds=text_inds,
  382. edges=edges,
  383. labels=labels,
  384. )
  385. return self.list_to_numpy(ann_infos)
  386. class AttnLabelEncode(BaseRecLabelEncode):
  387. """Convert between text-label and text-index"""
  388. def __init__(
  389. self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
  390. ):
  391. super(AttnLabelEncode, self).__init__(
  392. max_text_length, character_dict_path, use_space_char
  393. )
  394. def add_special_char(self, dict_character):
  395. self.beg_str = "sos"
  396. self.end_str = "eos"
  397. dict_character = [self.beg_str] + dict_character + [self.end_str]
  398. return dict_character
  399. def __call__(self, data):
  400. text = data["label"]
  401. text = self.encode(text)
  402. if text is None:
  403. return None
  404. if len(text) >= self.max_text_len:
  405. return None
  406. data["length"] = np.array(len(text))
  407. text = (
  408. [0]
  409. + text
  410. + [len(self.character) - 1]
  411. + [0] * (self.max_text_len - len(text) - 2)
  412. )
  413. data["label"] = np.array(text)
  414. return data
  415. def get_ignored_tokens(self):
  416. beg_idx = self.get_beg_end_flag_idx("beg")
  417. end_idx = self.get_beg_end_flag_idx("end")
  418. return [beg_idx, end_idx]
  419. def get_beg_end_flag_idx(self, beg_or_end):
  420. if beg_or_end == "beg":
  421. idx = np.array(self.dict[self.beg_str])
  422. elif beg_or_end == "end":
  423. idx = np.array(self.dict[self.end_str])
  424. else:
  425. assert False, "Unsupported type %s in get_beg_end_flag_idx" % beg_or_end
  426. return idx
  427. class RFLLabelEncode(BaseRecLabelEncode):
  428. """Convert between text-label and text-index"""
  429. def __init__(
  430. self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
  431. ):
  432. super(RFLLabelEncode, self).__init__(
  433. max_text_length, character_dict_path, use_space_char
  434. )
  435. def add_special_char(self, dict_character):
  436. self.beg_str = "sos"
  437. self.end_str = "eos"
  438. dict_character = [self.beg_str] + dict_character + [self.end_str]
  439. return dict_character
  440. def encode_cnt(self, text):
  441. cnt_label = [0.0] * len(self.character)
  442. for char_ in text:
  443. cnt_label[char_] += 1
  444. return np.array(cnt_label)
  445. def __call__(self, data):
  446. text = data["label"]
  447. text = self.encode(text)
  448. if text is None:
  449. return None
  450. if len(text) >= self.max_text_len:
  451. return None
  452. cnt_label = self.encode_cnt(text)
  453. data["length"] = np.array(len(text))
  454. text = (
  455. [0]
  456. + text
  457. + [len(self.character) - 1]
  458. + [0] * (self.max_text_len - len(text) - 2)
  459. )
  460. if len(text) != self.max_text_len:
  461. return None
  462. data["label"] = np.array(text)
  463. data["cnt_label"] = cnt_label
  464. return data
  465. def get_ignored_tokens(self):
  466. beg_idx = self.get_beg_end_flag_idx("beg")
  467. end_idx = self.get_beg_end_flag_idx("end")
  468. return [beg_idx, end_idx]
  469. def get_beg_end_flag_idx(self, beg_or_end):
  470. if beg_or_end == "beg":
  471. idx = np.array(self.dict[self.beg_str])
  472. elif beg_or_end == "end":
  473. idx = np.array(self.dict[self.end_str])
  474. else:
  475. assert False, "Unsupported type %s in get_beg_end_flag_idx" % beg_or_end
  476. return idx
  477. class SEEDLabelEncode(BaseRecLabelEncode):
  478. """Convert between text-label and text-index"""
  479. def __init__(
  480. self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
  481. ):
  482. super(SEEDLabelEncode, self).__init__(
  483. max_text_length, character_dict_path, use_space_char
  484. )
  485. def add_special_char(self, dict_character):
  486. self.padding = "padding"
  487. self.end_str = "eos"
  488. self.unknown = "unknown"
  489. dict_character = dict_character + [self.end_str, self.padding, self.unknown]
  490. return dict_character
  491. def __call__(self, data):
  492. text = data["label"]
  493. text = self.encode(text)
  494. if text is None:
  495. return None
  496. if len(text) >= self.max_text_len:
  497. return None
  498. data["length"] = np.array(len(text)) + 1 # conclude eos
  499. text = (
  500. text
  501. + [len(self.character) - 3]
  502. + [len(self.character) - 2] * (self.max_text_len - len(text) - 1)
  503. )
  504. data["label"] = np.array(text)
  505. return data
  506. class SRNLabelEncode(BaseRecLabelEncode):
  507. """Convert between text-label and text-index"""
  508. def __init__(
  509. self,
  510. max_text_length=25,
  511. character_dict_path=None,
  512. use_space_char=False,
  513. **kwargs,
  514. ):
  515. super(SRNLabelEncode, self).__init__(
  516. max_text_length, character_dict_path, use_space_char
  517. )
  518. def add_special_char(self, dict_character):
  519. dict_character = dict_character + [self.beg_str, self.end_str]
  520. return dict_character
  521. def __call__(self, data):
  522. text = data["label"]
  523. text = self.encode(text)
  524. char_num = len(self.character)
  525. if text is None:
  526. return None
  527. if len(text) > self.max_text_len:
  528. return None
  529. data["length"] = np.array(len(text))
  530. text = text + [char_num - 1] * (self.max_text_len - len(text))
  531. data["label"] = np.array(text)
  532. return data
  533. def get_ignored_tokens(self):
  534. beg_idx = self.get_beg_end_flag_idx("beg")
  535. end_idx = self.get_beg_end_flag_idx("end")
  536. return [beg_idx, end_idx]
  537. def get_beg_end_flag_idx(self, beg_or_end):
  538. if beg_or_end == "beg":
  539. idx = np.array(self.dict[self.beg_str])
  540. elif beg_or_end == "end":
  541. idx = np.array(self.dict[self.end_str])
  542. else:
  543. assert False, "Unsupported type %s in get_beg_end_flag_idx" % beg_or_end
  544. return idx
  545. class TableLabelEncode(AttnLabelEncode):
  546. """Convert between text-label and text-index"""
  547. def __init__(
  548. self,
  549. max_text_length,
  550. character_dict_path,
  551. replace_empty_cell_token=False,
  552. merge_no_span_structure=False,
  553. learn_empty_box=False,
  554. loc_reg_num=4,
  555. **kwargs,
  556. ):
  557. self.max_text_len = max_text_length
  558. self.lower = False
  559. self.learn_empty_box = learn_empty_box
  560. self.merge_no_span_structure = merge_no_span_structure
  561. self.replace_empty_cell_token = replace_empty_cell_token
  562. dict_character = []
  563. with open(character_dict_path, "rb") as fin:
  564. lines = fin.readlines()
  565. for line in lines:
  566. line = line.decode("utf-8").strip("\n").strip("\r\n")
  567. dict_character.append(line)
  568. if self.merge_no_span_structure:
  569. if "<td></td>" not in dict_character:
  570. dict_character.append("<td></td>")
  571. if "<td>" in dict_character:
  572. dict_character.remove("<td>")
  573. dict_character = self.add_special_char(dict_character)
  574. self.dict = {}
  575. for i, char in enumerate(dict_character):
  576. self.dict[char] = i
  577. self.idx2char = {v: k for k, v in self.dict.items()}
  578. self.character = dict_character
  579. self.loc_reg_num = loc_reg_num
  580. self.pad_idx = self.dict[self.beg_str]
  581. self.start_idx = self.dict[self.beg_str]
  582. self.end_idx = self.dict[self.end_str]
  583. self.td_token = ["<td>", "<td", "<eb></eb>", "<td></td>"]
  584. self.empty_bbox_token_dict = {
  585. "[]": "<eb></eb>",
  586. "[' ']": "<eb1></eb1>",
  587. "['<b>', ' ', '</b>']": "<eb2></eb2>",
  588. "['\\u2028', '\\u2028']": "<eb3></eb3>",
  589. "['<sup>', ' ', '</sup>']": "<eb4></eb4>",
  590. "['<b>', '</b>']": "<eb5></eb5>",
  591. "['<i>', ' ', '</i>']": "<eb6></eb6>",
  592. "['<b>', '<i>', '</i>', '</b>']": "<eb7></eb7>",
  593. "['<b>', '<i>', ' ', '</i>', '</b>']": "<eb8></eb8>",
  594. "['<i>', '</i>']": "<eb9></eb9>",
  595. "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": "<eb10></eb10>",
  596. }
  597. @property
  598. def _max_text_len(self):
  599. return self.max_text_len + 2
  600. def __call__(self, data):
  601. cells = data["cells"]
  602. structure = data["structure"]
  603. if self.merge_no_span_structure:
  604. structure = self._merge_no_span_structure(structure)
  605. if self.replace_empty_cell_token:
  606. structure = self._replace_empty_cell_token(structure, cells)
  607. # remove empty token and add " " to span token
  608. new_structure = []
  609. for token in structure:
  610. if token != "":
  611. if "span" in token and token[0] != " ":
  612. token = " " + token
  613. new_structure.append(token)
  614. # encode structure
  615. structure = self.encode(new_structure)
  616. if structure is None:
  617. return None
  618. data["length"] = len(structure)
  619. structure = [self.start_idx] + structure + [self.end_idx] # add sos abd eos
  620. structure = structure + [self.pad_idx] * (
  621. self._max_text_len - len(structure)
  622. ) # pad
  623. structure = np.array(structure)
  624. data["structure"] = structure
  625. if len(structure) > self._max_text_len:
  626. return None
  627. # encode box
  628. bboxes = np.zeros((self._max_text_len, self.loc_reg_num), dtype=np.float32)
  629. bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
  630. bbox_idx = 0
  631. for i, token in enumerate(structure):
  632. if self.idx2char[token] in self.td_token:
  633. if "bbox" in cells[bbox_idx] and len(cells[bbox_idx]["tokens"]) > 0:
  634. bbox = cells[bbox_idx]["bbox"].copy()
  635. bbox = np.array(bbox, dtype=np.float32).reshape(-1)
  636. bboxes[i] = bbox
  637. bbox_masks[i] = 1.0
  638. if self.learn_empty_box:
  639. bbox_masks[i] = 1.0
  640. bbox_idx += 1
  641. data["bboxes"] = bboxes
  642. data["bbox_masks"] = bbox_masks
  643. return data
  644. def _merge_no_span_structure(self, structure):
  645. """
  646. This code is refer from:
  647. https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
  648. """
  649. new_structure = []
  650. i = 0
  651. while i < len(structure):
  652. token = structure[i]
  653. if token == "<td>":
  654. token = "<td></td>"
  655. i += 1
  656. new_structure.append(token)
  657. i += 1
  658. return new_structure
  659. def _replace_empty_cell_token(self, token_list, cells):
  660. """
  661. This fun code is refer from:
  662. https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
  663. """
  664. bbox_idx = 0
  665. add_empty_bbox_token_list = []
  666. for token in token_list:
  667. if token in ["<td></td>", "<td", "<td>"]:
  668. if "bbox" not in cells[bbox_idx].keys():
  669. content = str(cells[bbox_idx]["tokens"])
  670. token = self.empty_bbox_token_dict[content]
  671. add_empty_bbox_token_list.append(token)
  672. bbox_idx += 1
  673. else:
  674. add_empty_bbox_token_list.append(token)
  675. return add_empty_bbox_token_list
  class TableMasterLabelEncode(TableLabelEncode):
      """Convert between text-label and text-index"""

      def __init__(
          self,
          max_text_length,
          character_dict_path,
          replace_empty_cell_token=False,
          merge_no_span_structure=False,
          learn_empty_box=False,
          loc_reg_num=4,
          **kwargs,
      ):
          super().__init__(
              max_text_length,
              character_dict_path,
              replace_empty_cell_token,
              merge_no_span_structure,
              learn_empty_box,
              loc_reg_num,
              **kwargs,
          )
          # Cache the indices of the padding and unknown tokens.
          self.pad_idx = self.dict[self.pad_str]
          self.unknown_idx = self.dict[self.unknown_str]

      @property
      def _max_text_len(self):
          """Sequence length used for padding; equals ``self.max_text_len``."""
          return self.max_text_len

      def add_special_char(self, dict_character):
          """Record the special token strings and append them to the dictionary.

          The four tokens are appended in the order unknown, start, end, pad.
          """
          self.beg_str = "<SOS>"
          self.end_str = "<EOS>"
          self.unknown_str = "<UKN>"
          self.pad_str = "<PAD>"
          specials = [self.unknown_str, self.beg_str, self.end_str, self.pad_str]
          return dict_character + specials
  class TableBoxEncode(object):
      """Convert ``data["bboxes"]`` to the output box format and normalize it.

      Conversion only happens when the output format is "xywh" and the input
      format is "xyxyxyxy" or "xyxy"; any other combination leaves the box
      values as they are. In every case the even columns are divided by the
      image width and the odd columns by the image height.
      """

      def __init__(self, in_box_format="xyxy", out_box_format="xyxy", **kwargs):
          assert out_box_format in ["xywh", "xyxy", "xyxyxyxy"]
          self.in_box_format = in_box_format
          self.out_box_format = out_box_format

      def __call__(self, data):
          """Convert and normalize the boxes in ``data``; returns ``data``.

          Reads ``data["image"].shape[:2]`` as (height, width) and rewrites
          ``data["bboxes"]``. When no conversion takes place the original
          array is normalized in place.
          """
          img_height, img_width = data["image"].shape[:2]
          bboxes = data["bboxes"]
          if self.in_box_format != self.out_box_format:
              if self.out_box_format == "xywh":
                  if self.in_box_format == "xyxyxyxy":
                      bboxes = self.xyxyxyxy2xywh(bboxes)
                  elif self.in_box_format == "xyxy":
                      bboxes = self.xyxy2xywh(bboxes)

          bboxes[:, 0::2] /= img_width
          bboxes[:, 1::2] /= img_height
          data["bboxes"] = bboxes
          return data

      def xyxyxyxy2xywh(self, boxes):
          """Turn rows of interleaved x/y coordinates into (x_min, y_min, w, h).

          The extrema are taken per row (``axis=1``), so every box gets its
          own bounds rather than the bounds of the whole array.

          NOTE(review): this returns the top-left corner while ``xyxy2xywh``
          returns the box center; confirm which convention consumers expect.
          """
          xs = boxes[:, 0::2]
          ys = boxes[:, 1::2]
          new_bboxes = np.zeros([len(boxes), 4])
          new_bboxes[:, 0] = xs.min(axis=1)  # x1
          new_bboxes[:, 1] = ys.min(axis=1)  # y1
          new_bboxes[:, 2] = xs.max(axis=1) - new_bboxes[:, 0]  # w
          new_bboxes[:, 3] = ys.max(axis=1) - new_bboxes[:, 1]  # h
          return new_bboxes

      def xyxy2xywh(self, bboxes):
          """Turn rows of (x1, y1, x2, y2) into (x_center, y_center, w, h)."""
          new_bboxes = np.empty_like(bboxes)
          new_bboxes[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2  # x center
          new_bboxes[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2  # y center
          new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]  # width
          new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]  # height
          return new_bboxes
  class SARLabelEncode(BaseRecLabelEncode):
      """Convert between text-label and text-index"""

      def __init__(
          self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
      ):
          super().__init__(max_text_length, character_dict_path, use_space_char)

      def add_special_char(self, dict_character):
          """Append the unknown, begin/end and padding tokens to the dictionary.

          Start and end share one token, so ``start_idx == end_idx``.
          """
          base = len(dict_character)
          extended = dict_character + ["<UKN>", "<BOS/EOS>", "<PAD>"]
          self.unknown_idx = base
          self.start_idx = base + 1
          self.end_idx = base + 1
          self.padding_idx = base + 2
          return extended

      def __call__(self, data):
          """Encode ``data["label"]`` into a padded index array.

          Returns None when the label cannot be encoded or when it has
          ``max_text_len - 1`` or more symbols.
          """
          encoded = self.encode(data["label"])
          if encoded is None or len(encoded) >= self.max_text_len - 1:
              return None
          data["length"] = np.array(len(encoded))
          target = [self.start_idx] + encoded + [self.end_idx]
          padding = [self.padding_idx] * (self.max_text_len - len(target))
          data["label"] = np.array(target + padding)
          return data

      def get_ignored_tokens(self):
          return [self.padding_idx]
  class SATRNLabelEncode(BaseRecLabelEncode):
      """Convert between text-label and text-index"""

      def __init__(
          self,
          max_text_length,
          character_dict_path=None,
          use_space_char=False,
          lower=False,
          **kwargs,
      ):
          super().__init__(max_text_length, character_dict_path, use_space_char)
          self.lower = lower

      def add_special_char(self, dict_character):
          """Append the unknown, begin/end and padding tokens to the dictionary.

          Start and end share one token, so ``start_idx == end_idx``.
          """
          base = len(dict_character)
          extended = dict_character + ["<UKN>", "<BOS/EOS>", "<PAD>"]
          self.unknown_idx = base
          self.start_idx = base + 1
          self.end_idx = base + 1
          self.padding_idx = base + 2
          return extended

      def encode(self, text):
          """Map each character to its index, using ``unknown_idx`` for misses.

          Returns None for an empty result.
          """
          if self.lower:
              text = text.lower()
          indices = [self.dict.get(ch, self.unknown_idx) for ch in text]
          return indices if indices else None

      def __call__(self, data):
          """Encode ``data["label"]``; over-long targets are truncated, not dropped."""
          encoded = self.encode(data["label"])
          if encoded is None:
              return None
          data["length"] = np.array(len(encoded))
          target = [self.start_idx] + encoded + [self.end_idx]
          limit = self.max_text_len
          if len(target) > limit:
              padded = target[:limit]
          else:
              padded = target + [self.padding_idx] * (limit - len(target))
          data["label"] = np.array(padded)
          return data

      def get_ignored_tokens(self):
          return [self.padding_idx]
  class PRENLabelEncode(BaseRecLabelEncode):
      """Label encoder that puts the padding, end and unknown tokens at indices 0, 1, 2."""

      def __init__(
          self, max_text_length, character_dict_path, use_space_char=False, **kwargs
      ):
          super().__init__(max_text_length, character_dict_path, use_space_char)

      def add_special_char(self, dict_character):
          """Prepend the three special tokens and record their fixed indices."""
          specials = ["<PAD>", "<EOS>", "<UNK>"]  # indices 0, 1, 2
          self.padding_idx = 0
          self.end_idx = 1
          self.unknown_idx = 2
          return specials + dict_character

      def encode(self, text):
          """Encode ``text`` as indices followed by the end token and padding.

          Returns None when ``text`` is empty or has ``max_text_len`` or more
          characters.
          """
          if not (0 < len(text) < self.max_text_len):
              return None
          if self.lower:
              text = text.lower()
          ids = [
              self.dict[ch] if ch in self.dict else self.unknown_idx for ch in text
          ]
          ids.append(self.end_idx)
          shortfall = self.max_text_len - len(ids)
          if shortfall > 0:
              ids += [self.padding_idx] * shortfall
          return ids

      def __call__(self, data):
          encoded_text = self.encode(data["label"])
          if encoded_text is None:
              return None
          data["label"] = np.array(encoded_text)
          return data
  class VQATokenLabelEncode(object):
      """
      Label encode for NLP VQA methods
      """

      def __init__(
          self,
          class_path,
          contains_re=False,
          add_special_ids=False,
          algorithm="LayoutXLM",
          use_textline_bbox_info=True,
          order_method=None,
          infer_mode=False,
          ocr_engine=None,
          **kwargs,
      ):
          """Load the tokenizer for ``algorithm`` and the label map from ``class_path``.

          ``algorithm`` must be one of "LayoutXLM", "LayoutLM" or "LayoutLMv2";
          ``order_method`` must be None or "tb-yx".
          """
          super(VQATokenLabelEncode, self).__init__()
          from paddlenlp.transformers import (
              LayoutXLMTokenizer,
              LayoutLMTokenizer,
              LayoutLMv2Tokenizer,
          )
          from ppocr.utils.utility import load_vqa_bio_label_maps

          tokenizer_dict = {
              "LayoutXLM": {
                  "class": LayoutXLMTokenizer,
                  "pretrained_model": "layoutxlm-base-uncased",
              },
              "LayoutLM": {
                  "class": LayoutLMTokenizer,
                  "pretrained_model": "layoutlm-base-uncased",
              },
              "LayoutLMv2": {
                  "class": LayoutLMv2Tokenizer,
                  "pretrained_model": "layoutlmv2-base-uncased",
              },
          }
          self.contains_re = contains_re
          tokenizer_config = tokenizer_dict[algorithm]
          self.tokenizer = tokenizer_config["class"].from_pretrained(
              tokenizer_config["pretrained_model"]
          )
          # Only the label -> id direction is kept on the instance.
          self.label2id_map, _ = load_vqa_bio_label_maps(class_path)
          self.add_special_ids = add_special_ids
          self.infer_mode = infer_mode
          self.ocr_engine = ocr_engine
          self.use_textline_bbox_info = use_textline_bbox_info
          self.order_method = order_method
          assert self.order_method in [None, "tb-yx"]

      def split_bbox(self, bbox, text, tokenizer):
          """Split a text-line box into per-token boxes along the x axis.

          The line width is shared out evenly per character of ``text``
          (spaces included). Every token of a word receives that word's box.
          Returns an empty list for empty text instead of dividing by zero.
          """
          if not text:
              return []
          x1, y1, x2, y2 = bbox
          unit_w = (x2 - x1) / len(text)
          token_bboxes = []
          for word in text.split():
              curr_w = len(word) * unit_w
              word_bbox = [x1, y1, x1 + curr_w, y2]
              token_bboxes.extend([word_bbox] * len(tokenizer.tokenize(word)))
              # Advance past the word and the single separator after it.
              x1 += (len(word) + 1) * unit_w
          return token_bboxes

      def filter_empty_contents(self, ocr_info):
          """
          find out the empty texts and remove the links
          """
          new_ocr_info = []
          empty_index = []
          for info in ocr_info:
              if len(info["transcription"]) > 0:
                  new_ocr_info.append(copy.deepcopy(info))
              else:
                  empty_index.append(info["id"])

          # Drop every link that touches an entry with empty transcription.
          for info in new_ocr_info:
              info["linking"] = [
                  link
                  for link in info["linking"]
                  if not (link[0] in empty_index or link[1] in empty_index)
              ]
          return new_ocr_info

      def __call__(self, data):
          """Tokenize every OCR entry and store the flattened results in ``data``.

          Writes "ocr_info", "input_ids", "token_type_ids", "bbox",
          "attention_mask", "labels", "segment_offset_id", "tokenizer_params"
          and "entities"; when training relation extraction
          (``contains_re and not infer_mode``) it also writes "relations",
          "id2label", "empty_entity" and "entity_id_to_index_map".
          """
          # load bbox and label info
          ocr_info = self._load_ocr_info(data)

          for info in ocr_info:
              if "bbox" not in info:
                  info["bbox"] = self.trans_poly_to_bbox(info["points"])

          if self.order_method == "tb-yx":
              ocr_info = order_by_tbyx(ocr_info)

          # for re
          train_re = self.contains_re and not self.infer_mode
          if train_re:
              ocr_info = self.filter_empty_contents(ocr_info)

          height, width, _ = data["image"].shape

          bbox_list = []
          input_ids_list = []
          token_type_ids_list = []
          segment_offset_id = []
          gt_label_list = []

          entities = []

          if train_re:
              relations = []
              id2label = {}
              entity_id_to_index_map = {}
              # Entries with empty text are skipped below (and were already
              # filtered out), so this set stays empty; it is still emitted.
              empty_entity = set()

          data["ocr_info"] = copy.deepcopy(ocr_info)

          for info in ocr_info:
              text = info["transcription"]
              if len(text) <= 0:
                  continue
              if train_re:
                  # for re
                  id2label[info["id"]] = info["label"]
                  relations.extend([tuple(sorted(l)) for l in info["linking"]])
              # smooth_box
              # Recompute the box from the polygon when one is available;
              # entries that only carry "bbox" keep the box they came with.
              if "points" in info:
                  info["bbox"] = self.trans_poly_to_bbox(info["points"])

              encode_res = self.tokenizer.encode(
                  text,
                  pad_to_max_seq_len=False,
                  return_attention_mask=True,
                  return_token_type_ids=True,
              )

              if not self.add_special_ids:
                  # TODO: use tok.all_special_ids to remove
                  encode_res["input_ids"] = encode_res["input_ids"][1:-1]
                  encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
                  encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]

              if self.use_textline_bbox_info:
                  bbox = [info["bbox"]] * len(encode_res["input_ids"])
              else:
                  bbox = self.split_bbox(
                      info["bbox"], info["transcription"], self.tokenizer
                  )
              if len(bbox) <= 0:
                  continue
              bbox = self._smooth_box(bbox, height, width)
              if self.add_special_ids:
                  bbox.insert(0, [0, 0, 0, 0])
                  bbox.append([0, 0, 0, 0])

              # parse label
              if not self.infer_mode:
                  label = info["label"]
                  gt_label = self._parse_label(label, encode_res)

              # construct entities for re
              start = len(input_ids_list)
              end = start + len(encode_res["input_ids"])
              if train_re:
                  if gt_label[0] != self.label2id_map["O"]:
                      entity_id_to_index_map[info["id"]] = len(entities)
                      entities.append(
                          {"start": start, "end": end, "label": label.upper()}
                      )
              else:
                  entities.append({"start": start, "end": end, "label": "O"})

              input_ids_list.extend(encode_res["input_ids"])
              token_type_ids_list.extend(encode_res["token_type_ids"])
              bbox_list.extend(bbox)
              segment_offset_id.append(len(input_ids_list))
              if not self.infer_mode:
                  gt_label_list.extend(gt_label)

          data["input_ids"] = input_ids_list
          data["token_type_ids"] = token_type_ids_list
          data["bbox"] = bbox_list
          data["attention_mask"] = [1] * len(input_ids_list)
          data["labels"] = gt_label_list
          data["segment_offset_id"] = segment_offset_id
          data["tokenizer_params"] = dict(
              padding_side=self.tokenizer.padding_side,
              pad_token_type_id=self.tokenizer.pad_token_type_id,
              pad_token_id=self.tokenizer.pad_token_id,
          )
          data["entities"] = entities

          if train_re:
              data["relations"] = relations
              data["id2label"] = id2label
              data["empty_entity"] = empty_entity
              data["entity_id_to_index_map"] = entity_id_to_index_map
          return data

      def trans_poly_to_bbox(self, poly):
          """Return the integer bounding box [x1, y1, x2, y2] of a polygon."""
          x1 = int(np.min([p[0] for p in poly]))
          x2 = int(np.max([p[0] for p in poly]))
          y1 = int(np.min([p[1] for p in poly]))
          y2 = int(np.max([p[1] for p in poly]))
          return [x1, y1, x2, y2]

      def _load_ocr_info(self, data):
          """Return the OCR entries for ``data``.

          In inference mode the entries come from running ``self.ocr_engine``
          on the image; otherwise ``data["label"]`` is parsed as JSON.
          """
          if self.infer_mode:
              ocr_result = self.ocr_engine.ocr(data["image"], cls=False)[0]
              ocr_info = []
              for res in ocr_result:
                  ocr_info.append(
                      {
                          "transcription": res[1][0],
                          "bbox": self.trans_poly_to_bbox(res[0]),
                          "points": res[0],
                      }
                  )
              return ocr_info
          else:
              info = data["label"]
              # read text info
              info_dict = json.loads(info)
              return info_dict

      def _smooth_box(self, bboxes, height, width):
          """Rescale boxes so both axes run from 0 to 1000; returns int lists."""
          bboxes = np.array(bboxes)
          bboxes[:, 0] = bboxes[:, 0] * 1000 / width
          bboxes[:, 2] = bboxes[:, 2] * 1000 / width
          bboxes[:, 1] = bboxes[:, 1] * 1000 / height
          bboxes[:, 3] = bboxes[:, 3] * 1000 / height
          bboxes = bboxes.astype("int64").tolist()
          return bboxes

      def _parse_label(self, label, encode_res):
          """Build one label id per token of ``encode_res["input_ids"]``.

          "other", "others" and "ignore" (case-insensitive) map every token
          to 0; any other label gives a "B-" id followed by "I-" ids.
          """
          gt_label = []
          if label.lower() in ["other", "others", "ignore"]:
              gt_label.extend([0] * len(encode_res["input_ids"]))
          else:
              gt_label.append(self.label2id_map[("b-" + label).upper()])
              gt_label.extend(
                  [self.label2id_map[("i-" + label).upper()]]
                  * (len(encode_res["input_ids"]) - 1)
              )
          return gt_label
  class MultiLabelEncode(BaseRecLabelEncode):
      """Encode one label twice: with a CTC encoder and with a second encoder.

      The second encoder is ``SARLabelEncode`` unless ``gtc_encode`` names
      another class.
      """

      def __init__(
          self,
          max_text_length,
          character_dict_path=None,
          use_space_char=False,
          gtc_encode=None,
          **kwargs,
      ):
          super().__init__(max_text_length, character_dict_path, use_space_char)

          self.ctc_encode = CTCLabelEncode(
              max_text_length, character_dict_path, use_space_char, **kwargs
          )
          self.gtc_encode_type = gtc_encode
          # NOTE(review): ``gtc_encode`` is evaluated with ``eval``; it must
          # only ever come from trusted configuration.
          gtc_cls = SARLabelEncode if gtc_encode is None else eval(gtc_encode)
          self.gtc_encode = gtc_cls(
              max_text_length, character_dict_path, use_space_char, **kwargs
          )

      def __call__(self, data):
          """Return a new dict holding both encodings, or None if either fails.

          The second label is stored under "label_sar" when no ``gtc_encode``
          was given and under "label_gtc" otherwise.
          """
          # Each encoder works on its own deep copy of the sample.
          ctc_input = copy.deepcopy(data)
          gtc_input = copy.deepcopy(data)
          data_out = {
              "img_path": data.get("img_path", None),
              "image": data["image"],
          }
          ctc = self.ctc_encode(ctc_input)
          gtc = self.gtc_encode(gtc_input)
          if ctc is None or gtc is None:
              return None
          gtc_key = "label_sar" if self.gtc_encode_type is None else "label_gtc"
          data_out["label_ctc"] = ctc["label"]
          data_out[gtc_key] = gtc["label"]
          data_out["length"] = ctc["length"]
          return data_out
  class NRTRLabelEncode(BaseRecLabelEncode):
      """Convert between text-label and text-index"""

      def __init__(
          self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
      ):
          super().__init__(max_text_length, character_dict_path, use_space_char)

      def __call__(self, data):
          """Encode ``data["label"]`` framed by indices 2 and 3 and padded with 0.

          Those indices are the positions of "<s>", "</s>" and "blank" in the
          list built by ``add_special_char``. Returns None when the label
          cannot be encoded or has ``max_text_len - 1`` or more symbols.
          """
          encoded = self.encode(data["label"])
          if encoded is None or len(encoded) >= self.max_text_len - 1:
              return None
          data["length"] = np.array(len(encoded))
          framed = [2] + encoded + [3]
          framed = framed + [0] * (self.max_text_len - len(framed))
          data["label"] = np.array(framed)
          return data

      def add_special_char(self, dict_character):
          return ["blank", "<unk>", "<s>", "</s>"] + dict_character
  class ParseQLabelEncode(BaseRecLabelEncode):
      """Convert between text-label and text-index"""

      BOS = "[B]"
      EOS = "[E]"
      PAD = "[P]"

      def __init__(
          self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
      ):
          super().__init__(max_text_length, character_dict_path, use_space_char)

      def __call__(self, data):
          """Encode ``data["label"]`` as BOS + indices + EOS, padded with PAD.

          Returns None when the label cannot be encoded or has
          ``max_text_len - 2`` or more symbols.
          """
          encoded = self.encode(data["label"])
          if encoded is None or len(encoded) >= self.max_text_len - 2:
              return None
          data["length"] = np.array(len(encoded))
          framed = [self.dict[self.BOS]] + encoded + [self.dict[self.EOS]]
          filler = [self.dict[self.PAD]] * (self.max_text_len - len(framed))
          data["label"] = np.array(framed + filler)
          return data

      def add_special_char(self, dict_character):
          # EOS goes first; BOS and PAD go after the regular characters.
          return [self.EOS] + dict_character + [self.BOS, self.PAD]
  class ViTSTRLabelEncode(BaseRecLabelEncode):
      """Convert between text-label and text-index"""

      def __init__(
          self,
          max_text_length,
          character_dict_path=None,
          use_space_char=False,
          ignore_index=0,
          **kwargs,
      ):
          super().__init__(max_text_length, character_dict_path, use_space_char)
          self.ignore_index = ignore_index

      def __call__(self, data):
          """Encode ``data["label"]`` into an array of ``max_text_len + 2`` entries.

          The sequence is ``ignore_index``, the encoded label, index 1, then
          ``ignore_index`` padding. Returns None when the label cannot be
          encoded or has ``max_text_len`` or more symbols.
          """
          encoded = self.encode(data["label"])
          if encoded is None or len(encoded) >= self.max_text_len:
              return None
          data["length"] = np.array(len(encoded))
          framed = [self.ignore_index] + encoded + [1]
          filler = [self.ignore_index] * (self.max_text_len + 2 - len(framed))
          data["label"] = np.array(framed + filler)
          return data

      def add_special_char(self, dict_character):
          return ["<s>", "</s>"] + dict_character
  class ABINetLabelEncode(BaseRecLabelEncode):
      """Convert between text-label and text-index"""

      def __init__(
          self,
          max_text_length,
          character_dict_path=None,
          use_space_char=False,
          ignore_index=100,
          **kwargs,
      ):
          super().__init__(max_text_length, character_dict_path, use_space_char)
          self.ignore_index = ignore_index

      def __call__(self, data):
          """Encode ``data["label"]`` into an array of ``max_text_len + 1`` entries.

          The encoded label is followed by index 0 (the "</s>" entry that
          ``add_special_char`` puts first) and ``ignore_index`` padding.
          Returns None when the label cannot be encoded or has
          ``max_text_len`` or more symbols.
          """
          encoded = self.encode(data["label"])
          if encoded is None or len(encoded) >= self.max_text_len:
              return None
          data["length"] = np.array(len(encoded))
          terminated = encoded + [0]
          filler = [self.ignore_index] * (self.max_text_len + 1 - len(terminated))
          data["label"] = np.array(terminated + filler)
          return data

      def add_special_char(self, dict_character):
          return ["</s>"] + dict_character
  class SRLabelEncode(BaseRecLabelEncode):
      """Encode a text label as a sequence of stroke digits.

      ``character_dict_path`` is read as whitespace-separated
      "<character> <digit sequence>" lines.
      """

      def __init__(
          self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
      ):
          super(SRLabelEncode, self).__init__(
              max_text_length, character_dict_path, use_space_char
          )
          self.dic = {}
          # NOTE(review): the file is opened with the platform default
          # encoding; confirm whether an explicit encoding is needed.
          with open(character_dict_path, "r") as fin:
              for line in fin:
                  character, sequence = line.strip().split()
                  self.dic[character] = sequence
          english_stroke_alphabet = "0123456789"
          self.english_stroke_dict = {
              digit: index for index, digit in enumerate(english_stroke_alphabet)
          }

      def encode(self, label):
          """Return ``(length, input_tensor)`` for ``label``.

          Characters missing from the stroke dictionary are skipped. The
          stroke string ends with a "0" terminator, which is counted in
          ``length``. ``input_tensor`` is an int64 array of ``max_text_len``
          entries whose slot 0 stays 0 and whose following slots hold the
          digits before the terminator.
          """
          stroke_sequence = "".join(
              self.dic[character] for character in label if character in self.dic
          )
          stroke_sequence += "0"
          length = len(stroke_sequence)

          input_tensor = np.zeros(self.max_text_len).astype("int64")
          for j in range(length - 1):
              input_tensor[j + 1] = self.english_stroke_dict[stroke_sequence[j]]

          return length, input_tensor

      def __call__(self, data):
          """Add "length" and "input_tensor" to ``data``; None if there is no label."""
          text = data["label"]
          # Reject a missing label before encoding: encode() iterates over it.
          if text is None:
              return None
          length, input_tensor = self.encode(text)

          data["length"] = length
          data["input_tensor"] = input_tensor
          return data
  class SPINLabelEncode(AttnLabelEncode):
      """Convert between text-label and text-index"""

      def __init__(
          self,
          max_text_length,
          character_dict_path=None,
          use_space_char=False,
          lower=True,
          **kwargs,
      ):
          super().__init__(max_text_length, character_dict_path, use_space_char)
          self.lower = lower

      def add_special_char(self, dict_character):
          """Prepend the "sos" and "eos" tokens, giving them indices 0 and 1."""
          self.beg_str = "sos"
          self.end_str = "eos"
          return [self.beg_str, self.end_str] + dict_character

      def __call__(self, data):
          """Encode ``data["label"]`` into an array of ``max_text_len + 2`` entries.

          The sequence is 0, the encoded label, 1, then 0 padding. Returns
          None when the label cannot be encoded or has more than
          ``max_text_len`` symbols.
          """
          encoded = self.encode(data["label"])
          if encoded is None or len(encoded) > self.max_text_len:
              return None
          data["length"] = np.array(len(encoded))
          target = [0] + encoded + [1]
          filler = [0] * (self.max_text_len + 2 - len(target))
          data["label"] = np.array(target + filler)
          return data
  class VLLabelEncode(BaseRecLabelEncode):
      """Convert between text-label and text-index"""

      def __init__(
          self, max_text_length, character_dict_path=None, use_space_char=False, **kwargs
      ):
          super().__init__(max_text_length, character_dict_path, use_space_char)
          # Rebuild the lookup table straight from the character list.
          self.dict = {char: i for i, char in enumerate(self.character)}

      def __call__(self, data):
          """Encode the label plus a randomly occluded variant of it.

          One character position is drawn at random. "label_sub" is that
          character and "label_res" is the label without it. All encoded
          indices are shifted up by one and padded with 0 to
          ``max_text_len``.
          """
          text = data["label"]  # original string
          # generate occluded text
          len_str = len(text)
          if len_str <= 0:
              return None
          change_num = 1
          order = list(range(len_str))
          change_id = sample(order, change_num)[0]
          label_sub = text[change_id]
          # Slicing covers the first, last and middle positions alike.
          label_res = text[:change_id] + text[change_id + 1 :]

          data["label_res"] = label_res  # remaining string
          data["label_sub"] = label_sub  # occluded character
          data["label_id"] = change_id  # character index

          # encode label
          encoded = self.encode(text)
          if encoded is None:
              return None
          shifted = [i + 1 for i in encoded]
          data["length"] = np.array(len(shifted))
          data["label"] = np.array(shifted + [0] * (self.max_text_len - len(shifted)))

          res_encoded = self.encode(label_res)
          sub_encoded = self.encode(label_sub)
          res_ids = [] if res_encoded is None else [i + 1 for i in res_encoded]
          sub_ids = [] if sub_encoded is None else [i + 1 for i in sub_encoded]
          data["length_res"] = np.array(len(res_ids))
          data["length_sub"] = np.array(len(sub_ids))
          data["label_res"] = np.array(
              res_ids + [0] * (self.max_text_len - len(res_ids))
          )
          data["label_sub"] = np.array(
              sub_ids + [0] * (self.max_text_len - len(sub_ids))
          )
          return data
  class CTLabelEncode(object):
      """Parse a JSON label into polygon arrays and their transcriptions."""

      def __init__(self, **kwargs):
          pass

      def __call__(self, data):
          """Store "polys" and "texts" in ``data``; None when there are no boxes.

          ``data["label"]`` is a JSON list whose items carry "points" and
          "transcription".
          """
          annotations = json.loads(data["label"])
          polys = [np.array(item["points"]) for item in annotations]
          texts = [item["transcription"] for item in annotations]
          if not polys:
              return None

          data["polys"] = polys
          data["texts"] = texts
          return data
  class CANLabelEncode(BaseRecLabelEncode):
      """Encode a sequence of symbols, given as a list or a whitespace-separated string."""

      def __init__(
          self,
          character_dict_path,
          max_text_length=100,
          use_space_char=False,
          lower=True,
          **kwargs,
      ):
          super().__init__(max_text_length, character_dict_path, use_space_char, lower)

      def encode(self, text_seq):
          """Map known symbols to indices, skipping symbols not in ``self.character``.

          Returns None for an empty result.
          """
          encoded = [
              self.dict.get(symbol) for symbol in text_seq if symbol in self.character
          ]
          return encoded if encoded else None

      def __call__(self, data):
          """Replace ``data["label"]`` with its encoding, end token appended."""
          symbols = data["label"]
          if isinstance(symbols, str):
              symbols = symbols.strip().split()
          # The end token is appended in place, also when a list was passed in.
          symbols.append(self.end_str)
          data["label"] = self.encode(symbols)
          return data
  class CPPDLabelEncode(BaseRecLabelEncode):
      """Convert between text-label and text-index"""

      def __init__(
          self,
          max_text_length,
          character_dict_path=None,
          use_space_char=False,
          ch=False,
          ignore_index=100,
          **kwargs,
      ):
          super(CPPDLabelEncode, self).__init__(
              max_text_length, character_dict_path, use_space_char
          )
          self.ch = ch
          self.ignore_index = ignore_index

      def __call__(self, data):
          """Encode ``data["label"]`` and attach the node labels.

          With ``ch`` set, ``encodech`` is used and "label_index" is written;
          otherwise ``encode`` is used and "label_order" is written. Both
          paths write "length", "label" and "label_node". Returns None when
          the label cannot be encoded or is too long.
          """
          text = data["label"]
          if self.ch:
              text, text_node_index, text_node_num = self.encodech(text)
              if text is None:
                  return None
              if len(text) > self.max_text_len:
                  return None
              data["length"] = np.array(len(text))

              text_pos_node = [1] * (len(text) + 1) + [0] * (
                  self.max_text_len - len(text)
              )

              text.append(0)  # eos
              text = text + [self.ignore_index] * (self.max_text_len + 1 - len(text))

              data["label"] = np.array(text)
              data["label_node"] = np.array(text_node_num + text_pos_node)
              data["label_index"] = np.array(text_node_index)
              return data
          else:
              text, text_char_node, ch_order = self.encode(text)
              if text is None:
                  return None
              if len(text) >= self.max_text_len:
                  return None
              data["length"] = np.array(len(text))

              text_pos_node = [1] * (len(text) + 1) + [0] * (
                  self.max_text_len - len(text)
              )

              text.append(0)  # eos

              text = text + [self.ignore_index] * (self.max_text_len + 1 - len(text))
              data["label"] = np.array(text)
              data["label_node"] = np.array(text_char_node + text_pos_node)
              data["label_order"] = np.array(ch_order)

              return data

      def add_special_char(self, dict_character):
          """Prepend "</s>" (index 0) and record the dictionary size."""
          dict_character = ["</s>"] + dict_character
          self.num_character = len(dict_character)
          return dict_character

      def encode(self, text):
          """Return ``(text_list, text_node, ch_order)``, or three Nones.

          ``text_list`` holds the indices of the known characters.
          ``text_node`` holds one count per dictionary entry, with entry 0
          preset to 1. ``ch_order`` holds ``[index, occurrence, position]``
          triples for characters in the text plus ``[index, 1, 0]`` triples
          for characters absent from it, truncated to ``max_text_len + 1``
          entries and then sorted.
          """
          if len(text) == 0 or len(text) > self.max_text_len:
              return None, None, None
          if self.lower:
              text = text.lower()
          text_node = [0 for _ in range(self.num_character)]
          text_node[0] = 1
          text_list = []
          ch_order = []
          order = 1
          for char in text:
              if char not in self.dict:
                  continue
              char_idx = self.dict[char]
              text_list.append(char_idx)
              text_node[char_idx] += 1
              ch_order.append([char_idx, text_node[char_idx], order])
              order += 1

          no_ch_order = [
              [self.dict[char], 1, 0] for char in self.character if char not in text
          ]
          random.shuffle(no_ch_order)
          ch_order = ch_order + no_ch_order
          ch_order = ch_order[: self.max_text_len + 1]

          if len(text_list) == 0:
              return None, None, None
          # list.sort() sorts in place and returns None, so sort first and
          # return the list itself.
          ch_order.sort()
          return text_list, text_node, ch_order

      def encodech(self, text):
          """Return ``(text_list, text_node_index, text_node_num)``, or three Nones.

          ``text_node_index`` is the sorted list of sampled character indices
          and ``text_node_num`` gives the occurrence count for each of them
          (0 for the randomly sampled characters absent from the text).
          """
          if len(text) == 0 or len(text) > self.max_text_len:
              return None, None, None
          if self.lower:
              text = text.lower()
          text_node_dict = {0: 1}
          character_index = list(range(self.num_character))
          text_list = []
          for char in text:
              if char not in self.dict:
                  continue
              i_c = self.dict[char]
              text_list.append(i_c)
              if i_c in text_node_dict:
                  text_node_dict[i_c] += 1
              else:
                  text_node_dict[i_c] = 1
          for ic in list(text_node_dict):
              character_index.remove(ic)
          # NOTE(review): the node count of 37 is hard-coded here; sample()
          # raises if the text already has more distinct nodes than that.
          none_char_index = sample(character_index, 37 - len(text_node_dict))
          for ic in none_char_index:
              text_node_dict[ic] = 0

          text_node_index = sorted(text_node_dict)
          text_node_num = [text_node_dict[k] for k in text_node_index]
          if len(text_list) == 0:
              return None, None, None
          return text_list, text_node_index, text_node_num
  1539. class LatexOCRLabelEncode(object):
  1540. def __init__(
  1541. self,
  1542. rec_char_dict_path,
  1543. **kwargs,
  1544. ):
  1545. # Set the TOKENIZERS_PARALLELISM environment variable to 'false' to suppress
  1546. # the warning: "The current process just got forked, Disabling parallelism to avoid deadlocks..
  1547. # To disable this warning, please explicitly set TOKENIZERS_PARALLELISM=(true | false)" from tokenizers
  1548. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  1549. from tokenizers import Tokenizer as TokenizerFast
  1550. self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)
  1551. self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"]
  1552. self.pad_token_id = 0
  1553. self.bos_token_id = 1
  1554. self.eos_token_id = 2
  1555. def _convert_encoding(
  1556. self,
  1557. encoding,
  1558. return_token_type_ids=None,
  1559. return_attention_mask=None,
  1560. return_overflowing_tokens=False,
  1561. return_special_tokens_mask=False,
  1562. return_offsets_mapping=False,
  1563. return_length=False,
  1564. verbose=True,
  1565. ):
  1566. if return_token_type_ids is None:
  1567. return_token_type_ids = "token_type_ids" in self.model_input_names
  1568. if return_attention_mask is None:
  1569. return_attention_mask = "attention_mask" in self.model_input_names
  1570. if return_overflowing_tokens and encoding.overflowing is not None:
  1571. encodings = [encoding] + encoding.overflowing
  1572. else:
  1573. encodings = [encoding]
  1574. encoding_dict = defaultdict(list)
  1575. for e in encodings:
  1576. encoding_dict["input_ids"].append(e.ids)
  1577. if return_token_type_ids:
  1578. encoding_dict["token_type_ids"].append(e.type_ids)
  1579. if return_attention_mask:
  1580. encoding_dict["attention_mask"].append(e.attention_mask)
  1581. if return_special_tokens_mask:
  1582. encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
  1583. if return_offsets_mapping:
  1584. encoding_dict["offset_mapping"].append(e.offsets)
  1585. if return_length:
  1586. encoding_dict["length"].append(len(e.ids))
  1587. return encoding_dict, encodings
  1588. def encode(
  1589. self,
  1590. text,
  1591. text_pair=None,
  1592. return_token_type_ids=False,
  1593. add_special_tokens=True,
  1594. is_split_into_words=False,
  1595. ):
  1596. batched_input = text
  1597. encodings = self.tokenizer.encode_batch(
  1598. batched_input,
  1599. add_special_tokens=add_special_tokens,
  1600. is_pretokenized=is_split_into_words,
  1601. )
  1602. tokens_and_encodings = [
  1603. self._convert_encoding(
  1604. encoding=encoding,
  1605. return_token_type_ids=False,
  1606. return_attention_mask=None,
  1607. return_overflowing_tokens=False,
  1608. return_special_tokens_mask=False,
  1609. return_offsets_mapping=False,
  1610. return_length=False,
  1611. verbose=True,
  1612. )
  1613. for encoding in encodings
  1614. ]
  1615. sanitized_tokens = {}
  1616. for key in tokens_and_encodings[0][0].keys():
  1617. stack = [e for item, _ in tokens_and_encodings for e in item[key]]
  1618. sanitized_tokens[key] = stack
  1619. return sanitized_tokens
  1620. def __call__(self, eqs):
  1621. topk = self.encode(eqs)
  1622. for k, p in zip(topk, [[self.bos_token_id, self.eos_token_id], [1, 1]]):
  1623. process_seq = [[p[0]] + x + [p[1]] for x in topk[k]]
  1624. max_length = 0
  1625. for seq in process_seq:
  1626. max_length = max(max_length, len(seq))
  1627. labels = np.zeros((len(process_seq), max_length), dtype="int64")
  1628. for idx, seq in enumerate(process_seq):
  1629. l = len(seq)
  1630. labels[idx][:l] = seq
  1631. topk[k] = labels
  1632. return (
  1633. np.array(topk["input_ids"]).astype(np.int64),
  1634. np.array(topk["attention_mask"]).astype(np.int64),
  1635. max_length,
  1636. )
  1637. class ExplicitEnum(str, Enum):
  1638. """
  1639. Enum with more explicit error message for missing values.
  1640. """
  1641. @classmethod
  1642. def _missing_(cls, value):
  1643. raise ValueError(
  1644. f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
  1645. )
  1646. class TruncationStrategy(ExplicitEnum):
  1647. """
  1648. Possible values for the `truncation` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in
  1649. an IDE.
  1650. """
  1651. ONLY_FIRST = "only_first"
  1652. ONLY_SECOND = "only_second"
  1653. LONGEST_FIRST = "longest_first"
  1654. DO_NOT_TRUNCATE = "do_not_truncate"
  1655. class PaddingStrategy(ExplicitEnum):
  1656. """
  1657. Possible values for the `padding` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in an
  1658. IDE.
  1659. """
  1660. LONGEST = "longest"
  1661. MAX_LENGTH = "max_length"
  1662. DO_NOT_PAD = "do_not_pad"
  1663. class UniMERNetLabelEncode(object):
  1664. SPECIAL_TOKENS_ATTRIBUTES = [
  1665. "bos_token",
  1666. "eos_token",
  1667. "unk_token",
  1668. "sep_token",
  1669. "pad_token",
  1670. "cls_token",
  1671. "mask_token",
  1672. "additional_special_tokens",
  1673. ]
  1674. def __init__(
  1675. self,
  1676. rec_char_dict_path,
  1677. max_seq_len,
  1678. **kwargs,
  1679. ):
  1680. # Set the TOKENIZERS_PARALLELISM environment variable to 'false' to suppress
  1681. # the warning: "The current process just got forked, Disabling parallelism to avoid deadlocks..
  1682. # To disable this warning, please explicitly set TOKENIZERS_PARALLELISM=(true | false)" from tokenizers
  1683. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  1684. from tokenizers import Tokenizer as TokenizerFast
  1685. from tokenizers import AddedToken
  1686. self._unk_token = "<unk>"
  1687. self._bos_token = "<s>"
  1688. self._eos_token = "</s>"
  1689. self._pad_token = "<pad>"
  1690. self._sep_token = None
  1691. self._cls_token = None
  1692. self._mask_token = None
  1693. self._additional_special_tokens = []
  1694. self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"]
  1695. self.max_seq_len = max_seq_len
  1696. self.pad_token_id = 1
  1697. self.bos_token_id = 0
  1698. self.eos_token_id = 2
  1699. self.padding_side = "right"
  1700. self.pad_token = "<pad>"
  1701. self.pad_token_type_id = 0
  1702. self.pad_to_multiple_of = None
  1703. fast_tokenizer_file = os.path.join(rec_char_dict_path, "tokenizer.json")
  1704. tokenizer_config_file = os.path.join(
  1705. rec_char_dict_path, "tokenizer_config.json"
  1706. )
  1707. self.tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
  1708. added_tokens_decoder = {}
  1709. added_tokens_map = {}
  1710. if tokenizer_config_file is not None:
  1711. with open(
  1712. tokenizer_config_file, encoding="utf-8"
  1713. ) as tokenizer_config_handle:
  1714. init_kwargs = json.load(tokenizer_config_handle)
  1715. if "added_tokens_decoder" in init_kwargs:
  1716. for idx, token in init_kwargs["added_tokens_decoder"].items():
  1717. if isinstance(token, dict):
  1718. token = AddedToken(**token)
  1719. if isinstance(token, AddedToken):
  1720. added_tokens_decoder[int(idx)] = token
  1721. added_tokens_map[str(token)] = token
  1722. else:
  1723. raise ValueError(
  1724. f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance"
  1725. )
  1726. init_kwargs["added_tokens_decoder"] = added_tokens_decoder
  1727. added_tokens_decoder = init_kwargs.pop("added_tokens_decoder", {})
  1728. tokens_to_add = [
  1729. token
  1730. for index, token in sorted(
  1731. added_tokens_decoder.items(), key=lambda x: x[0]
  1732. )
  1733. if token not in added_tokens_decoder
  1734. ]
  1735. added_tokens_encoder = self.added_tokens_encoder(added_tokens_decoder)
  1736. encoder = list(added_tokens_encoder.keys()) + [
  1737. str(token) for token in tokens_to_add
  1738. ]
  1739. tokens_to_add += [
  1740. token
  1741. for token in self.all_special_tokens_extended
  1742. if token not in encoder and token not in tokens_to_add
  1743. ]
  1744. if len(tokens_to_add) > 0:
  1745. is_last_special = None
  1746. tokens = []
  1747. special_tokens = self.all_special_tokens
  1748. for token in tokens_to_add:
  1749. is_special = (
  1750. (token.special or str(token) in special_tokens)
  1751. if isinstance(token, AddedToken)
  1752. else str(token) in special_tokens
  1753. )
  1754. if is_last_special is None or is_last_special == is_special:
  1755. tokens.append(token)
  1756. else:
  1757. self._add_tokens(tokens, special_tokens=is_last_special)
  1758. tokens = [token]
  1759. is_last_special = is_special
  1760. if tokens:
  1761. self._add_tokens(tokens, special_tokens=is_last_special)
  1762. def _add_tokens(self, new_tokens, special_tokens=False) -> int:
  1763. if special_tokens:
  1764. return self.tokenizer.add_special_tokens(new_tokens)
  1765. return self.tokenizer.add_tokens(new_tokens)
  1766. def added_tokens_encoder(self, added_tokens_decoder):
  1767. return {
  1768. k.content: v
  1769. for v, k in sorted(added_tokens_decoder.items(), key=lambda item: item[0])
  1770. }
  1771. @property
  1772. def all_special_tokens(self):
  1773. all_toks = [str(s) for s in self.all_special_tokens_extended]
  1774. return all_toks
  1775. @property
  1776. def all_special_tokens_extended(self):
  1777. all_tokens = []
  1778. seen = set()
  1779. for value in self.special_tokens_map_extended.values():
  1780. if isinstance(value, (list, tuple)):
  1781. tokens_to_add = [token for token in value if str(token) not in seen]
  1782. else:
  1783. tokens_to_add = [value] if str(value) not in seen else []
  1784. seen.update(map(str, tokens_to_add))
  1785. all_tokens.extend(tokens_to_add)
  1786. return all_tokens
  1787. @property
  1788. def special_tokens_map_extended(self):
  1789. set_attr = {}
  1790. for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
  1791. attr_value = getattr(self, "_" + attr)
  1792. if attr_value:
  1793. set_attr[attr] = attr_value
  1794. return set_attr
  1795. def set_truncation_and_padding(
  1796. self,
  1797. padding_strategy,
  1798. truncation_strategy,
  1799. max_length,
  1800. stride,
  1801. pad_to_multiple_of,
  1802. ):
  1803. _truncation = self.tokenizer.truncation
  1804. _padding = self.tokenizer.padding
  1805. # Set truncation and padding on the backend tokenizer
  1806. if truncation_strategy == TruncationStrategy.DO_NOT_TRUNCATE:
  1807. if _truncation is not None:
  1808. self._tokenizer.no_truncation()
  1809. else:
  1810. target = {
  1811. "max_length": max_length,
  1812. "stride": stride,
  1813. "strategy": truncation_strategy.value,
  1814. "direction": "right",
  1815. }
  1816. if _truncation is None:
  1817. current = None
  1818. else:
  1819. current = {k: _truncation.get(k, None) for k in target}
  1820. if current != target:
  1821. self.tokenizer.enable_truncation(**target)
  1822. if padding_strategy == PaddingStrategy.DO_NOT_PAD:
  1823. if _padding is not None:
  1824. self.tokenizer.no_padding()
  1825. else:
  1826. length = (
  1827. max_length if padding_strategy == PaddingStrategy.MAX_LENGTH else None
  1828. )
  1829. target = {
  1830. "length": length,
  1831. "direction": self.padding_side,
  1832. "pad_id": self.pad_token_id,
  1833. "pad_token": self.pad_token,
  1834. "pad_type_id": self.pad_token_type_id,
  1835. "pad_to_multiple_of": pad_to_multiple_of,
  1836. }
  1837. if _padding != target:
  1838. self.tokenizer.enable_padding(**target)
  1839. def _convert_encoding(
  1840. self,
  1841. encoding,
  1842. return_token_type_ids=None,
  1843. return_attention_mask=None,
  1844. return_overflowing_tokens=False,
  1845. return_special_tokens_mask=False,
  1846. return_offsets_mapping=False,
  1847. return_length=False,
  1848. verbose=True,
  1849. ):
  1850. if return_token_type_ids is None:
  1851. return_token_type_ids = "token_type_ids" in self.model_input_names
  1852. if return_attention_mask is None:
  1853. return_attention_mask = "attention_mask" in self.model_input_names
  1854. if return_overflowing_tokens and encoding.overflowing is not None:
  1855. encodings = [encoding] + encoding.overflowing
  1856. else:
  1857. encodings = [encoding]
  1858. encoding_dict = defaultdict(list)
  1859. for e in encodings:
  1860. encoding_dict["input_ids"].append(e.ids)
  1861. if return_token_type_ids:
  1862. encoding_dict["token_type_ids"].append(e.type_ids)
  1863. if return_attention_mask:
  1864. encoding_dict["attention_mask"].append(e.attention_mask)
  1865. if return_special_tokens_mask:
  1866. encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
  1867. if return_offsets_mapping:
  1868. encoding_dict["offset_mapping"].append(e.offsets)
  1869. if return_length:
  1870. encoding_dict["length"].append(len(e.ids))
  1871. return encoding_dict, encodings
  1872. def encode(
  1873. self,
  1874. text,
  1875. text_pair=None,
  1876. return_token_type_ids=False,
  1877. add_special_tokens=True,
  1878. is_split_into_words=False,
  1879. ):
  1880. batched_input = text
  1881. self.set_truncation_and_padding(
  1882. padding_strategy=PaddingStrategy.LONGEST,
  1883. truncation_strategy=TruncationStrategy.LONGEST_FIRST,
  1884. max_length=self.max_seq_len,
  1885. stride=0,
  1886. pad_to_multiple_of=None,
  1887. )
  1888. encodings = self.tokenizer.encode_batch(
  1889. batched_input,
  1890. add_special_tokens=add_special_tokens,
  1891. is_pretokenized=is_split_into_words,
  1892. )
  1893. tokens_and_encodings = [
  1894. self._convert_encoding(
  1895. encoding=encoding,
  1896. return_token_type_ids=False,
  1897. return_attention_mask=None,
  1898. return_overflowing_tokens=False,
  1899. return_special_tokens_mask=False,
  1900. return_offsets_mapping=False,
  1901. return_length=False,
  1902. verbose=True,
  1903. )
  1904. for encoding in encodings
  1905. ]
  1906. sanitized_tokens = {}
  1907. for key in tokens_and_encodings[0][0].keys():
  1908. stack = [e for item, _ in tokens_and_encodings for e in item[key]]
  1909. sanitized_tokens[key] = stack
  1910. return sanitized_tokens
  1911. def __call__(self, data):
  1912. eqs = data["label"]
  1913. topk = self.encode([eqs])
  1914. for k, p in zip(topk, [[self.bos_token_id, self.eos_token_id], [1, 1]]):
  1915. process_seq = [x for x in topk[k]]
  1916. max_length = 0
  1917. for seq in process_seq:
  1918. max_length = max(max_length, len(seq))
  1919. data["label"] = np.array(topk["input_ids"]).astype(np.int64)[0]
  1920. data["attention_mask"] = np.array(topk["attention_mask"]).astype(np.int64)[0]
  1921. return data