convert_slow_tokenizer.py 67 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873
  1. # Copyright 2018 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
Utilities to convert slow tokenizers into their fast tokenizer counterparts.
All the conversions are grouped here to gather SentencePiece dependencies outside of the fast tokenizers files and
allow us to make our dependency on SentencePiece optional.
  18. """
  19. import warnings
  20. from functools import lru_cache
  21. from typing import Optional
  22. from packaging import version
  23. from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
  24. from tokenizers.models import BPE, Unigram, WordPiece
  25. from tqdm import tqdm
  26. from .utils import is_protobuf_available, is_sentencepiece_available, logging, requires_backends
  27. from .utils.import_utils import PROTOBUF_IMPORT_ERROR
  28. logger = logging.get_logger(__name__)
  29. def import_protobuf(error_message=""):
  30. if is_sentencepiece_available():
  31. from sentencepiece import sentencepiece_model_pb2
  32. return sentencepiece_model_pb2
  33. if is_protobuf_available():
  34. import google.protobuf
  35. if version.parse(google.protobuf.__version__) < version.parse("4.0.0"):
  36. from transformers.utils import sentencepiece_model_pb2
  37. else:
  38. from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
  39. return sentencepiece_model_pb2
  40. else:
  41. raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
  42. def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
  43. if add_prefix_space:
  44. prepend_scheme = "always"
  45. if not getattr(original_tokenizer, "legacy", True):
  46. prepend_scheme = "first"
  47. else:
  48. prepend_scheme = "never"
  49. return prepend_scheme
  50. def generate_merges(vocab, vocab_scores):
  51. reverse = vocab_scores is not None
  52. vocab_scores = dict(vocab_scores) if reverse else vocab
  53. merges = []
  54. for merge, piece_score in vocab_scores.items():
  55. local = []
  56. for index in range(1, len(merge)):
  57. piece_l, piece_r = merge[:index], merge[index:]
  58. if piece_l in vocab and piece_r in vocab:
  59. local.append((piece_l, piece_r, piece_score))
  60. local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
  61. merges.extend(local)
  62. merges = sorted(merges, key=lambda val: (val[2], len(val[0]), len(val[1])), reverse=reverse)
  63. merges = [(val[0], val[1]) for val in merges]
  64. return merges
  65. class SentencePieceExtractor:
  66. """
  67. Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
  68. """
  69. def __init__(self, model: str):
  70. requires_backends(self, "sentencepiece")
  71. from sentencepiece import SentencePieceProcessor
  72. self.sp = SentencePieceProcessor()
  73. self.sp.Load(model)
  74. def extract(self, vocab_scores=None) -> tuple[dict[str, int], list[tuple]]:
  75. """
  76. By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
  77. order the merges with respect to the piece scores instead.
  78. """
  79. sp = self.sp
  80. vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
  81. merges = generate_merges(vocab, vocab_scores)
  82. return vocab, merges
  83. class GemmaSentencePieceExtractor(SentencePieceExtractor):
  84. def extract(self, vocab_scores=None) -> tuple[dict[str, int], list[tuple]]:
  85. """
  86. By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
  87. order the merges with respect to the piece scores instead.
  88. """
  89. sp = self.sp
  90. vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
  91. # If "\t" is missing in the vocab, we have to do this to support merges
  92. # "<0x09>" is the bytefallback for `\t`
  93. if "\t" not in vocab:
  94. vocab["\t"] = vocab.get("<0x09>")
  95. merges = generate_merges(vocab, vocab_scores)
  96. return vocab, merges
  97. def check_number_comma(piece: str) -> bool:
  98. return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()
  99. class Converter:
  100. def __init__(self, original_tokenizer):
  101. self.original_tokenizer = original_tokenizer
  102. def converted(self) -> Tokenizer:
  103. raise NotImplementedError()
  104. class BertConverter(Converter):
  105. def converted(self) -> Tokenizer:
  106. vocab = self.original_tokenizer.vocab
  107. tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
  108. tokenize_chinese_chars = False
  109. strip_accents = False
  110. do_lower_case = False
  111. if hasattr(self.original_tokenizer, "basic_tokenizer"):
  112. tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
  113. strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
  114. do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
  115. tokenizer.normalizer = normalizers.BertNormalizer(
  116. clean_text=True,
  117. handle_chinese_chars=tokenize_chinese_chars,
  118. strip_accents=strip_accents,
  119. lowercase=do_lower_case,
  120. )
  121. tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  122. cls = str(self.original_tokenizer.cls_token)
  123. sep = str(self.original_tokenizer.sep_token)
  124. cls_token_id = self.original_tokenizer.cls_token_id
  125. sep_token_id = self.original_tokenizer.sep_token_id
  126. tokenizer.post_processor = processors.TemplateProcessing(
  127. single=f"{cls}:0 $A:0 {sep}:0",
  128. pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
  129. special_tokens=[
  130. (cls, cls_token_id),
  131. (sep, sep_token_id),
  132. ],
  133. )
  134. tokenizer.decoder = decoders.WordPiece(prefix="##")
  135. return tokenizer
  136. class SplinterConverter(Converter):
  137. def converted(self) -> Tokenizer:
  138. vocab = self.original_tokenizer.vocab
  139. tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
  140. tokenize_chinese_chars = False
  141. strip_accents = False
  142. do_lower_case = False
  143. if hasattr(self.original_tokenizer, "basic_tokenizer"):
  144. tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
  145. strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
  146. do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
  147. tokenizer.normalizer = normalizers.BertNormalizer(
  148. clean_text=True,
  149. handle_chinese_chars=tokenize_chinese_chars,
  150. strip_accents=strip_accents,
  151. lowercase=do_lower_case,
  152. )
  153. tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  154. cls = str(self.original_tokenizer.cls_token)
  155. sep = str(self.original_tokenizer.sep_token)
  156. question = str(self.original_tokenizer.question_token)
  157. dot = "."
  158. cls_token_id = self.original_tokenizer.cls_token_id
  159. sep_token_id = self.original_tokenizer.sep_token_id
  160. question_token_id = self.original_tokenizer.question_token_id
  161. dot_token_id = self.original_tokenizer.convert_tokens_to_ids(".")
  162. if self.original_tokenizer.padding_side == "right":
  163. pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
  164. else:
  165. pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"
  166. tokenizer.post_processor = processors.TemplateProcessing(
  167. single=f"{cls}:0 $A:0 {sep}:0",
  168. pair=pair,
  169. special_tokens=[
  170. (cls, cls_token_id),
  171. (sep, sep_token_id),
  172. (question, question_token_id),
  173. (dot, dot_token_id),
  174. ],
  175. )
  176. tokenizer.decoder = decoders.WordPiece(prefix="##")
  177. return tokenizer
  178. class FunnelConverter(Converter):
  179. def converted(self) -> Tokenizer:
  180. vocab = self.original_tokenizer.vocab
  181. tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
  182. tokenize_chinese_chars = False
  183. strip_accents = False
  184. do_lower_case = False
  185. if hasattr(self.original_tokenizer, "basic_tokenizer"):
  186. tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
  187. strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
  188. do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
  189. tokenizer.normalizer = normalizers.BertNormalizer(
  190. clean_text=True,
  191. handle_chinese_chars=tokenize_chinese_chars,
  192. strip_accents=strip_accents,
  193. lowercase=do_lower_case,
  194. )
  195. tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  196. cls = str(self.original_tokenizer.cls_token)
  197. sep = str(self.original_tokenizer.sep_token)
  198. cls_token_id = self.original_tokenizer.cls_token_id
  199. sep_token_id = self.original_tokenizer.sep_token_id
  200. tokenizer.post_processor = processors.TemplateProcessing(
  201. single=f"{cls}:2 $A:0 {sep}:0", # token_type_id is 2 for Funnel transformer
  202. pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1",
  203. special_tokens=[
  204. (cls, cls_token_id),
  205. (sep, sep_token_id),
  206. ],
  207. )
  208. tokenizer.decoder = decoders.WordPiece(prefix="##")
  209. return tokenizer
  210. class MPNetConverter(Converter):
  211. def converted(self) -> Tokenizer:
  212. vocab = self.original_tokenizer.vocab
  213. tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
  214. tokenize_chinese_chars = False
  215. strip_accents = False
  216. do_lower_case = False
  217. if hasattr(self.original_tokenizer, "basic_tokenizer"):
  218. tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
  219. strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
  220. do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
  221. tokenizer.normalizer = normalizers.BertNormalizer(
  222. clean_text=True,
  223. handle_chinese_chars=tokenize_chinese_chars,
  224. strip_accents=strip_accents,
  225. lowercase=do_lower_case,
  226. )
  227. tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  228. cls = str(self.original_tokenizer.cls_token)
  229. sep = str(self.original_tokenizer.sep_token)
  230. cls_token_id = self.original_tokenizer.cls_token_id
  231. sep_token_id = self.original_tokenizer.sep_token_id
  232. tokenizer.post_processor = processors.TemplateProcessing(
  233. single=f"{cls}:0 $A:0 {sep}:0",
  234. pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1", # MPNet uses two [SEP] tokens
  235. special_tokens=[
  236. (cls, cls_token_id),
  237. (sep, sep_token_id),
  238. ],
  239. )
  240. tokenizer.decoder = decoders.WordPiece(prefix="##")
  241. return tokenizer
  242. class OpenAIGPTConverter(Converter):
  243. def converted(self) -> Tokenizer:
  244. vocab = self.original_tokenizer.encoder
  245. merges = list(self.original_tokenizer.bpe_ranks.keys())
  246. unk_token = self.original_tokenizer.unk_token
  247. tokenizer = Tokenizer(
  248. BPE(
  249. vocab=vocab,
  250. merges=merges,
  251. dropout=None,
  252. unk_token=str(unk_token),
  253. end_of_word_suffix="</w>",
  254. fuse_unk=False,
  255. )
  256. )
  257. if tokenizer.token_to_id(str(unk_token)) is not None:
  258. tokenizer.add_special_tokens([str(unk_token)])
  259. tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
  260. tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  261. tokenizer.decoder = decoders.BPEDecoder(suffix="</w>")
  262. return tokenizer
  263. class GPT2Converter(Converter):
  264. def converted(
  265. self, vocab: Optional[dict[str, int]] = None, merges: Optional[list[tuple[str, str]]] = None
  266. ) -> Tokenizer:
  267. if not vocab:
  268. vocab = self.original_tokenizer.encoder
  269. if not merges:
  270. merges = list(self.original_tokenizer.bpe_ranks)
  271. tokenizer = Tokenizer(
  272. BPE(
  273. vocab=vocab,
  274. merges=merges,
  275. dropout=None,
  276. continuing_subword_prefix="",
  277. end_of_word_suffix="",
  278. fuse_unk=False,
  279. )
  280. )
  281. add_prefix_space = getattr(self.original_tokenizer, "add_prefix_space", False)
  282. tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
  283. tokenizer.decoder = decoders.ByteLevel()
  284. if getattr(self.original_tokenizer, "add_bos_token", False):
  285. bos = self.original_tokenizer.bos_token
  286. bos_token_id = self.original_tokenizer.bos_token_id
  287. tokenizer.post_processor = processors.TemplateProcessing(
  288. single=f"{bos}:0 $A:0",
  289. pair=f"{bos}:0 $A:0 $B:1",
  290. special_tokens=[
  291. (bos, bos_token_id),
  292. ],
  293. )
  294. else:
  295. # XXX trim_offsets=False actually means this post_processor doesn't
  296. # really do anything.
  297. tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
  298. return tokenizer
  299. class HerbertConverter(Converter):
  300. def converted(self) -> Tokenizer:
  301. tokenizer_info_str = "#version:"
  302. token_suffix = "</w>"
  303. vocab = self.original_tokenizer.encoder
  304. merges = list(self.original_tokenizer.bpe_ranks.keys())
  305. if tokenizer_info_str in merges[0][0]:
  306. merges = merges[1:]
  307. tokenizer = Tokenizer(
  308. BPE(
  309. vocab,
  310. merges,
  311. dropout=None,
  312. unk_token=self.original_tokenizer.unk_token,
  313. end_of_word_suffix=token_suffix,
  314. )
  315. )
  316. tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
  317. tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  318. tokenizer.decoder = decoders.BPEDecoder(suffix=token_suffix)
  319. tokenizer.post_processor = processors.BertProcessing(
  320. sep=(self.original_tokenizer.sep_token, self.original_tokenizer.sep_token_id),
  321. cls=(self.original_tokenizer.cls_token, self.original_tokenizer.cls_token_id),
  322. )
  323. return tokenizer
  324. class Qwen2Converter(Converter):
  325. def converted(
  326. self, vocab: Optional[dict[str, int]] = None, merges: Optional[list[tuple[str, str]]] = None
  327. ) -> Tokenizer:
  328. if not vocab:
  329. vocab = self.original_tokenizer.encoder
  330. if not merges:
  331. merges = list(self.original_tokenizer.bpe_ranks.keys())
  332. tokenizer = Tokenizer(
  333. BPE(
  334. vocab=vocab,
  335. merges=merges,
  336. dropout=None,
  337. unk_token=None,
  338. continuing_subword_prefix="",
  339. end_of_word_suffix="",
  340. fuse_unk=False,
  341. byte_fallback=False,
  342. )
  343. )
  344. tokenizer.normalizer = normalizers.NFC()
  345. tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
  346. [
  347. pre_tokenizers.Split(
  348. Regex(
  349. r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
  350. ),
  351. behavior="isolated",
  352. invert=False,
  353. ),
  354. pre_tokenizers.ByteLevel(
  355. add_prefix_space=getattr(self.original_tokenizer, "add_prefix_space", False),
  356. use_regex=False,
  357. ),
  358. ]
  359. )
  360. tokenizer.decoder = decoders.ByteLevel()
  361. tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
  362. return tokenizer
  363. class RobertaConverter(Converter):
  364. def converted(self) -> Tokenizer:
  365. ot = self.original_tokenizer
  366. vocab = ot.encoder
  367. merges = list(ot.bpe_ranks.keys())
  368. tokenizer = Tokenizer(
  369. BPE(
  370. vocab=vocab,
  371. merges=merges,
  372. dropout=None,
  373. continuing_subword_prefix="",
  374. end_of_word_suffix="",
  375. fuse_unk=False,
  376. )
  377. )
  378. tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
  379. tokenizer.decoder = decoders.ByteLevel()
  380. tokenizer.post_processor = processors.RobertaProcessing(
  381. sep=(ot.sep_token, ot.sep_token_id),
  382. cls=(ot.cls_token, ot.cls_token_id),
  383. add_prefix_space=ot.add_prefix_space,
  384. trim_offsets=True, # True by default on Roberta (historical)
  385. )
  386. return tokenizer
  387. class RoFormerConverter(Converter):
  388. def converted(self) -> Tokenizer:
  389. from .models.roformer.tokenization_utils import JiebaPreTokenizer
  390. vocab = self.original_tokenizer.vocab
  391. tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
  392. strip_accents = False
  393. do_lower_case = False
  394. if hasattr(self.original_tokenizer, "basic_tokenizer"):
  395. strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
  396. do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
  397. tokenizer.normalizer = normalizers.BertNormalizer(
  398. clean_text=True,
  399. handle_chinese_chars=False,
  400. strip_accents=strip_accents,
  401. lowercase=do_lower_case,
  402. )
  403. tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))
  404. cls = str(self.original_tokenizer.cls_token)
  405. sep = str(self.original_tokenizer.sep_token)
  406. cls_token_id = self.original_tokenizer.cls_token_id
  407. sep_token_id = self.original_tokenizer.sep_token_id
  408. tokenizer.post_processor = processors.TemplateProcessing(
  409. single=f"{cls}:0 $A:0 {sep}:0",
  410. pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
  411. special_tokens=[
  412. (cls, cls_token_id),
  413. (sep, sep_token_id),
  414. ],
  415. )
  416. tokenizer.decoder = decoders.WordPiece(prefix="##")
  417. return tokenizer
  418. class DebertaConverter(Converter):
  419. def converted(self) -> Tokenizer:
  420. ot = self.original_tokenizer
  421. vocab = ot.encoder
  422. merges = list(ot.bpe_ranks.keys())
  423. tokenizer = Tokenizer(
  424. BPE(
  425. vocab=vocab,
  426. merges=merges,
  427. dropout=None,
  428. continuing_subword_prefix="",
  429. end_of_word_suffix="",
  430. fuse_unk=False,
  431. )
  432. )
  433. tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
  434. tokenizer.decoder = decoders.ByteLevel()
  435. tokenizer.post_processor = processors.TemplateProcessing(
  436. single="[CLS]:0 $A:0 [SEP]:0",
  437. pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
  438. special_tokens=[
  439. ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
  440. ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
  441. ],
  442. )
  443. return tokenizer
  444. class SpmConverter(Converter):
  445. handle_byte_fallback = False
  446. SpmExtractor = SentencePieceExtractor
  447. special_tokens = {}
  448. def __init__(self, *args):
  449. requires_backends(self, "protobuf")
  450. super().__init__(*args)
  451. # from .utils import sentencepiece_model_pb2 as model_pb2
  452. model_pb2 = import_protobuf()
  453. m = model_pb2.ModelProto()
  454. with open(self.original_tokenizer.vocab_file, "rb") as f:
  455. m.ParseFromString(f.read())
  456. self.proto = m
  457. if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
  458. warnings.warn(
  459. "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
  460. " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
  461. " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
  462. "unknown tokens into a sequence of byte tokens matching the original piece of text."
  463. )
  464. def vocab(self, proto):
  465. return [(piece.piece, piece.score) for piece in proto.pieces]
  466. def unk_id(self, proto):
  467. return proto.trainer_spec.unk_id
  468. def tokenizer(self, proto):
  469. model_type = proto.trainer_spec.model_type
  470. vocab_scores = self.vocab(proto)
  471. if model_type == 1:
  472. tokenizer = Tokenizer(
  473. Unigram(
  474. vocab_scores,
  475. unk_id=self.unk_id(proto),
  476. byte_fallback=self.handle_byte_fallback,
  477. )
  478. )
  479. elif model_type == 2:
  480. _, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
  481. bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
  482. tokenizer = Tokenizer(
  483. BPE(
  484. bpe_vocab,
  485. merges,
  486. unk_token=proto.trainer_spec.unk_piece,
  487. fuse_unk=True,
  488. byte_fallback=self.handle_byte_fallback,
  489. dropout=None,
  490. )
  491. )
  492. else:
  493. raise Exception(
  494. "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
  495. )
  496. # control tokens are special
  497. # user defined symbols are not
  498. # both user and control tokens are AddedTokens
  499. # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
  500. spm_added_tokens = [
  501. (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
  502. for id, p in enumerate(proto.pieces)
  503. if p.type in [3, 4]
  504. ]
  505. tokenizer.add_tokens(
  506. [
  507. AddedToken(token, normalized=False, special=special)
  508. for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
  509. ]
  510. )
  511. return tokenizer
  512. def normalizer(self, proto):
  513. precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
  514. _normalizers = [
  515. normalizers.Strip(left=False, right=True), # stripping is important
  516. normalizers.Replace(Regex(" {2,}"), "▁"),
  517. ]
  518. if not precompiled_charsmap:
  519. return normalizers.Sequence(_normalizers)
  520. else:
  521. return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)
  522. def pre_tokenizer(self, replacement, add_prefix_space):
  523. prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
  524. return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
  525. def post_processor(self):
  526. return None
  527. def decoder(self, replacement, add_prefix_space):
  528. prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
  529. return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
  530. def converted(self) -> Tokenizer:
  531. tokenizer = self.tokenizer(self.proto)
  532. # Tokenizer assemble
  533. normalizer = self.normalizer(self.proto)
  534. if normalizer is not None:
  535. tokenizer.normalizer = normalizer
  536. replacement = "▁"
  537. add_prefix_space = True
  538. if hasattr(self.original_tokenizer, "add_prefix_space"):
  539. add_prefix_space = self.original_tokenizer.add_prefix_space
  540. pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
  541. if pre_tokenizer is not None:
  542. tokenizer.pre_tokenizer = pre_tokenizer
  543. tokenizer.decoder = self.decoder(replacement, add_prefix_space)
  544. post_processor = self.post_processor()
  545. if post_processor:
  546. tokenizer.post_processor = post_processor
  547. return tokenizer
  548. class AlbertConverter(SpmConverter):
  549. def vocab(self, proto):
  550. return [
  551. (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
  552. for piece in proto.pieces
  553. ]
  554. def normalizer(self, proto):
  555. list_normalizers = [
  556. normalizers.Replace("``", '"'),
  557. normalizers.Replace("''", '"'),
  558. ]
  559. if not self.original_tokenizer.keep_accents:
  560. list_normalizers.append(normalizers.NFKD())
  561. list_normalizers.append(normalizers.StripAccents())
  562. if self.original_tokenizer.do_lower_case:
  563. list_normalizers.append(normalizers.Lowercase())
  564. precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
  565. if precompiled_charsmap:
  566. list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
  567. list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
  568. return normalizers.Sequence(list_normalizers)
  569. def post_processor(self):
  570. return processors.TemplateProcessing(
  571. single="[CLS]:0 $A:0 [SEP]:0",
  572. pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
  573. special_tokens=[
  574. ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
  575. ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
  576. ],
  577. )
  578. class BarthezConverter(SpmConverter):
  579. def unk_id(self, proto):
  580. unk_id = 3
  581. return unk_id
  582. def post_processor(self):
  583. return processors.TemplateProcessing(
  584. single="<s> $A </s>",
  585. pair="<s> $A </s> </s> $B </s>",
  586. special_tokens=[
  587. ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
  588. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  589. ],
  590. )
  591. class CamembertConverter(SpmConverter):
  592. def vocab(self, proto):
  593. vocab = [
  594. ("<s>NOTUSED", 0.0),
  595. ("<pad>", 0.0),
  596. ("</s>NOTUSED", 0.0),
  597. ("<unk>", 0.0),
  598. ("<unk>NOTUSED", -100),
  599. ]
  600. # We down-grade the original SentencePiece by -100 to avoid using it and use our added token instead
  601. vocab += [(piece.piece, piece.score) for piece in proto.pieces[1:]]
  602. vocab += [("<mask>", 0.0)]
  603. return vocab
  604. def unk_id(self, proto):
  605. # See vocab unk position
  606. return 3
  607. def post_processor(self):
  608. return processors.TemplateProcessing(
  609. single="<s> $A </s>",
  610. pair="<s> $A </s> </s> $B </s>",
  611. special_tokens=[
  612. ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
  613. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  614. ],
  615. )
  616. class DebertaV2Converter(SpmConverter):
  617. def pre_tokenizer(self, replacement, add_prefix_space):
  618. list_pretokenizers = []
  619. if self.original_tokenizer.split_by_punct:
  620. list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
  621. prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
  622. list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme))
  623. return pre_tokenizers.Sequence(list_pretokenizers)
  624. def normalizer(self, proto):
  625. list_normalizers = []
  626. if self.original_tokenizer.do_lower_case:
  627. list_normalizers.append(normalizers.Lowercase())
  628. list_normalizers.append(normalizers.Strip())
  629. precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
  630. if precompiled_charsmap:
  631. list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
  632. list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
  633. return normalizers.Sequence(list_normalizers)
  634. def post_processor(self):
  635. return processors.TemplateProcessing(
  636. single="[CLS]:0 $A:0 [SEP]:0",
  637. pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
  638. special_tokens=[
  639. ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
  640. ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
  641. ],
  642. )
  643. class MBartConverter(SpmConverter):
  644. def vocab(self, proto):
  645. vocab = [
  646. ("<s>", 0.0),
  647. ("<pad>", 0.0),
  648. ("</s>", 0.0),
  649. ("<unk>", 0.0),
  650. ]
  651. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  652. vocab += [
  653. ("ar_AR", 0.0),
  654. ("cs_CZ", 0.0),
  655. ("de_DE", 0.0),
  656. ("en_XX", 0.0),
  657. ("es_XX", 0.0),
  658. ("et_EE", 0.0),
  659. ("fi_FI", 0.0),
  660. ("fr_XX", 0.0),
  661. ("gu_IN", 0.0),
  662. ("hi_IN", 0.0),
  663. ("it_IT", 0.0),
  664. ("ja_XX", 0.0),
  665. ("kk_KZ", 0.0),
  666. ("ko_KR", 0.0),
  667. ("lt_LT", 0.0),
  668. ("lv_LV", 0.0),
  669. ("my_MM", 0.0),
  670. ("ne_NP", 0.0),
  671. ("nl_XX", 0.0),
  672. ("ro_RO", 0.0),
  673. ("ru_RU", 0.0),
  674. ("si_LK", 0.0),
  675. ("tr_TR", 0.0),
  676. ("vi_VN", 0.0),
  677. ("zh_CN", 0.0),
  678. ]
  679. vocab += [("<mask>", 0.0)]
  680. return vocab
  681. def unk_id(self, proto):
  682. return 3
  683. def post_processor(self):
  684. return processors.TemplateProcessing(
  685. single="$A </s> en_XX",
  686. pair="$A $B </s> en_XX",
  687. special_tokens=[
  688. ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
  689. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  690. ],
  691. )
  692. class MBart50Converter(SpmConverter):
  693. def vocab(self, proto):
  694. vocab = [
  695. ("<s>", 0.0),
  696. ("<pad>", 0.0),
  697. ("</s>", 0.0),
  698. ("<unk>", 0.0),
  699. ]
  700. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  701. vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)] # fmt: skip
  702. vocab += [("<mask>", 0.0)]
  703. return vocab
  704. def unk_id(self, proto):
  705. return 3
  706. def post_processor(self):
  707. return processors.TemplateProcessing(
  708. single="en_XX $A </s>",
  709. pair="en_XX $A $B </s>",
  710. special_tokens=[
  711. ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
  712. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  713. ],
  714. )
  715. class NllbConverter(SpmConverter):
  716. def vocab(self, proto):
  717. vocab = [
  718. ("<s>", 0.0),
  719. ("<pad>", 0.0),
  720. ("</s>", 0.0),
  721. ("<unk>", 0.0),
  722. ]
  723. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  724. return vocab
  725. def unk_id(self, proto):
  726. return 3
  727. def post_processor(self):
  728. return processors.TemplateProcessing(
  729. single="eng_Latn $A </s>",
  730. pair="eng_Latn $A $B </s>",
  731. special_tokens=[
  732. ("eng_Latn", self.original_tokenizer.convert_tokens_to_ids("eng_Latn")),
  733. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  734. ],
  735. )
  736. class SeamlessM4TConverter(SpmConverter):
  737. def vocab(self, proto):
  738. vocab = [
  739. ("<pad>", 0.0),
  740. ("<unk>", 0.0),
  741. ("<s>", 0.0),
  742. ("</s>", 0.0),
  743. ]
  744. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  745. return vocab
  746. def unk_id(self, proto):
  747. return self.original_tokenizer.unk_token_id
  748. def post_processor(self):
  749. return processors.TemplateProcessing(
  750. single="__eng__ $A </s>",
  751. pair="__eng__ $A $B </s>",
  752. special_tokens=[
  753. ("__eng__", self.original_tokenizer.convert_tokens_to_ids("__eng__")),
  754. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  755. ],
  756. )
  757. class XLMRobertaConverter(SpmConverter):
  758. def vocab(self, proto):
  759. vocab = [
  760. ("<s>", 0.0),
  761. ("<pad>", 0.0),
  762. ("</s>", 0.0),
  763. ("<unk>", 0.0),
  764. ]
  765. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  766. vocab += [("<mask>", 0.0)]
  767. return vocab
  768. def unk_id(self, proto):
  769. unk_id = 3
  770. return unk_id
  771. def post_processor(self):
  772. return processors.TemplateProcessing(
  773. single="<s> $A </s>",
  774. pair="<s> $A </s> </s> $B </s>",
  775. special_tokens=[
  776. ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
  777. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  778. ],
  779. )
  780. class XLNetConverter(SpmConverter):
  781. def vocab(self, proto):
  782. return [
  783. (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
  784. for piece in proto.pieces
  785. ]
  786. def normalizer(self, proto):
  787. list_normalizers = [
  788. normalizers.Replace("``", '"'),
  789. normalizers.Replace("''", '"'),
  790. ]
  791. if not self.original_tokenizer.keep_accents:
  792. list_normalizers.append(normalizers.NFKD())
  793. list_normalizers.append(normalizers.StripAccents())
  794. if self.original_tokenizer.do_lower_case:
  795. list_normalizers.append(normalizers.Lowercase())
  796. precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
  797. if precompiled_charsmap:
  798. list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
  799. list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
  800. return normalizers.Sequence(list_normalizers)
  801. def post_processor(self):
  802. return processors.TemplateProcessing(
  803. single="$A:0 <sep>:0 <cls>:2",
  804. pair="$A:0 <sep>:0 $B:1 <sep>:1 <cls>:2",
  805. special_tokens=[
  806. ("<sep>", self.original_tokenizer.convert_tokens_to_ids("<sep>")),
  807. ("<cls>", self.original_tokenizer.convert_tokens_to_ids("<cls>")),
  808. ],
  809. )
  810. class ReformerConverter(SpmConverter):
  811. pass
  812. class RemBertConverter(SpmConverter):
  813. # Inspired from AlbertConverter
  814. def normalizer(self, proto):
  815. list_normalizers = [
  816. normalizers.Replace("``", '"'),
  817. normalizers.Replace("''", '"'),
  818. normalizers.Replace(Regex(" {2,}"), " "),
  819. ]
  820. if not self.original_tokenizer.keep_accents:
  821. list_normalizers.append(normalizers.NFKD())
  822. list_normalizers.append(normalizers.StripAccents())
  823. if self.original_tokenizer.do_lower_case:
  824. list_normalizers.append(normalizers.Lowercase())
  825. precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
  826. if precompiled_charsmap:
  827. list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
  828. return normalizers.Sequence(list_normalizers)
  829. def post_processor(self):
  830. return processors.TemplateProcessing(
  831. single="[CLS]:0 $A:0 [SEP]:0",
  832. pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
  833. special_tokens=[
  834. ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
  835. ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
  836. ],
  837. )
  838. class BertGenerationConverter(SpmConverter):
  839. pass
  840. class PegasusConverter(SpmConverter):
  841. def vocab(self, proto):
  842. vocab = [
  843. (self.original_tokenizer.pad_token, 0.0),
  844. (self.original_tokenizer.eos_token, 0.0),
  845. ]
  846. if self.original_tokenizer.mask_token_sent is not None:
  847. vocab += [(self.original_tokenizer.mask_token_sent, 0.0)]
  848. if (
  849. self.original_tokenizer.mask_token is not None
  850. and self.original_tokenizer.mask_token_id < self.original_tokenizer.offset
  851. ):
  852. vocab += [(self.original_tokenizer.mask_token, 0.0)]
  853. vocab += [(f"<unk_{i}>", -100.0) for i in range(2, self.original_tokenizer.offset)]
  854. vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
  855. return vocab
  856. def unk_id(self, proto):
  857. return proto.trainer_spec.unk_id + self.original_tokenizer.offset
  858. def pre_tokenizer(self, replacement, add_prefix_space):
  859. prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
  860. return pre_tokenizers.Sequence(
  861. [
  862. pre_tokenizers.WhitespaceSplit(),
  863. pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
  864. ]
  865. )
  866. def post_processor(self):
  867. eos = self.original_tokenizer.eos_token
  868. special_tokens = [
  869. (eos, self.original_tokenizer.eos_token_id),
  870. ]
  871. return processors.TemplateProcessing(single=["$A", eos], pair=["$A", "$B", eos], special_tokens=special_tokens)
  872. class T5Converter(SpmConverter):
  873. def vocab(self, proto):
  874. num_extra_ids = self.original_tokenizer._extra_ids
  875. vocab = [(piece.piece, piece.score) for piece in proto.pieces]
  876. vocab += [(f"<extra_id_{i}>", 0.0) for i in range(num_extra_ids - 1, -1, -1)]
  877. return vocab
  878. def post_processor(self):
  879. return processors.TemplateProcessing(
  880. single=["$A", "</s>"],
  881. pair=["$A", "</s>", "$B", "</s>"],
  882. special_tokens=[
  883. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  884. ],
  885. )
  886. class UdopConverter(SpmConverter):
  887. def post_processor(self):
  888. return processors.TemplateProcessing(
  889. single=["$A", "</s>"],
  890. pair=["$A", "</s>", "$B", "</s>"],
  891. special_tokens=[
  892. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  893. ],
  894. )
  895. class WhisperConverter(Converter):
  896. def converted(self) -> Tokenizer:
  897. vocab = self.original_tokenizer.encoder
  898. merges = list(self.original_tokenizer.bpe_ranks.keys())
  899. tokenizer = Tokenizer(
  900. BPE(
  901. vocab=vocab,
  902. merges=merges,
  903. dropout=None,
  904. continuing_subword_prefix="",
  905. end_of_word_suffix="",
  906. fuse_unk=False,
  907. )
  908. )
  909. tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
  910. tokenizer.decoder = decoders.ByteLevel()
  911. prefix_token_ids = self.original_tokenizer.prefix_tokens
  912. prefixes = self.original_tokenizer.convert_ids_to_tokens(prefix_token_ids)
  913. eos = self.original_tokenizer.eos_token
  914. eos_token_id = self.original_tokenizer.eos_token_id
  915. prefix_template = " ".join([f"{token}:0" for token in prefixes])
  916. tokenizer.post_processor = processors.TemplateProcessing(
  917. single=f"{prefix_template} $A:0 {eos}:0",
  918. pair=f"{prefix_template} $A:0 $B:1 {eos}:1",
  919. special_tokens=[
  920. (eos, eos_token_id),
  921. *zip(prefixes, prefix_token_ids),
  922. ],
  923. )
  924. return tokenizer
  925. class BigBirdConverter(SpmConverter):
  926. def post_processor(self):
  927. return processors.TemplateProcessing(
  928. single="[CLS]:0 $A:0 [SEP]:0",
  929. pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
  930. special_tokens=[
  931. ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
  932. ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
  933. ],
  934. )
  935. class CLIPConverter(Converter):
  936. def converted(self) -> Tokenizer:
  937. vocab = self.original_tokenizer.encoder
  938. merges = list(self.original_tokenizer.bpe_ranks.keys())
  939. unk_token = self.original_tokenizer.unk_token
  940. tokenizer = Tokenizer(
  941. BPE(
  942. vocab=vocab,
  943. merges=merges,
  944. dropout=None,
  945. continuing_subword_prefix="",
  946. end_of_word_suffix="</w>",
  947. fuse_unk=False,
  948. unk_token=str(unk_token),
  949. )
  950. )
  951. tokenizer.normalizer = normalizers.Sequence(
  952. [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
  953. )
  954. tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
  955. [
  956. pre_tokenizers.Split(
  957. Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
  958. behavior="removed",
  959. invert=True,
  960. ),
  961. pre_tokenizers.ByteLevel(add_prefix_space=False),
  962. ]
  963. )
  964. tokenizer.decoder = decoders.ByteLevel()
  965. # Hack to have a ByteLevel and TemplaceProcessor
  966. tokenizer.post_processor = processors.RobertaProcessing(
  967. sep=(self.original_tokenizer.eos_token, self.original_tokenizer.eos_token_id),
  968. cls=(self.original_tokenizer.bos_token, self.original_tokenizer.bos_token_id),
  969. add_prefix_space=False,
  970. trim_offsets=False,
  971. )
  972. return tokenizer
  973. class LayoutLMv2Converter(Converter):
  974. def converted(self) -> Tokenizer:
  975. vocab = self.original_tokenizer.vocab
  976. tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
  977. tokenize_chinese_chars = False
  978. strip_accents = False
  979. do_lower_case = True
  980. if hasattr(self.original_tokenizer, "basic_tokenizer"):
  981. tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
  982. strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
  983. do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
  984. tokenizer.normalizer = normalizers.BertNormalizer(
  985. clean_text=True,
  986. handle_chinese_chars=tokenize_chinese_chars,
  987. strip_accents=strip_accents,
  988. lowercase=do_lower_case,
  989. )
  990. tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
  991. cls = str(self.original_tokenizer.cls_token)
  992. sep = str(self.original_tokenizer.sep_token)
  993. cls_token_id = self.original_tokenizer.cls_token_id
  994. sep_token_id = self.original_tokenizer.sep_token_id
  995. tokenizer.post_processor = processors.TemplateProcessing(
  996. single=f"{cls}:0 $A:0 {sep}:0",
  997. pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
  998. special_tokens=[
  999. (cls, cls_token_id),
  1000. (sep, sep_token_id),
  1001. ],
  1002. )
  1003. tokenizer.decoder = decoders.WordPiece(prefix="##")
  1004. return tokenizer
  1005. class BlenderbotConverter(Converter):
  1006. def converted(self) -> Tokenizer:
  1007. ot = self.original_tokenizer
  1008. vocab = ot.encoder
  1009. merges = list(ot.bpe_ranks.keys())
  1010. tokenizer = Tokenizer(
  1011. BPE(
  1012. vocab=vocab,
  1013. merges=merges,
  1014. dropout=None,
  1015. continuing_subword_prefix="",
  1016. end_of_word_suffix="",
  1017. fuse_unk=False,
  1018. )
  1019. )
  1020. tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
  1021. tokenizer.decoder = decoders.ByteLevel()
  1022. tokenizer.post_processor = processors.TemplateProcessing(
  1023. single=f"$A:0 {ot.eos_token}:0",
  1024. special_tokens=[
  1025. (ot.eos_token, ot.eos_token_id),
  1026. ],
  1027. )
  1028. return tokenizer
  1029. class XGLMConverter(SpmConverter):
  1030. def vocab(self, proto):
  1031. vocab = [
  1032. ("<s>", 0.0),
  1033. ("<pad>", 0.0),
  1034. ("</s>", 0.0),
  1035. ("<unk>", 0.0),
  1036. ]
  1037. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  1038. vocab += [("<madeupword0>", 0.0), ("<madeupword1>", 0.0), ("<madeupword2>", 0.0), ("<madeupword3>", 0.0), ("<madeupword4>", 0.0), ("<madeupword5>", 0.0), ("<madeupword6>", 0.0)] # fmt: skip
  1039. return vocab
  1040. def unk_id(self, proto):
  1041. unk_id = 3
  1042. return unk_id
  1043. def post_processor(self):
  1044. return processors.TemplateProcessing(
  1045. single="</s> $A",
  1046. pair="</s> $A </s> </s> $B",
  1047. special_tokens=[
  1048. ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
  1049. ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
  1050. ],
  1051. )
  1052. class GemmaConverter(SpmConverter):
  1053. handle_byte_fallback = True
  1054. SpmExtractor = GemmaSentencePieceExtractor
  1055. # start and end of turn tokens must be marked as special
  1056. special_tokens = {"<start_of_turn>", "<end_of_turn>"}
  1057. """"
  1058. split_by_unicode_script: true
  1059. split_by_number: true
  1060. split_by_whitespace: true
  1061. treat_whitespace_as_suffix: false
  1062. allow_whitespace_only_pieces: true
  1063. split_digits: true
  1064. byte_fallback: true
  1065. """
  1066. def normalizer(self, proto):
  1067. return normalizers.Replace(" ", "▁")
  1068. def vocab(self, proto):
  1069. vocab = [
  1070. (self.original_tokenizer.pad_token, 0.0),
  1071. (self.original_tokenizer.eos_token, 0.0),
  1072. (self.original_tokenizer.bos_token, 0.0),
  1073. ]
  1074. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  1075. # Older gemma tokenizers had a missing tab token, so we fix that here
  1076. if not any(x[0] == "\t" for x in vocab):
  1077. override_index = next((i for i, x in enumerate(vocab) if x[0] == "<0x09>"), None)
  1078. if override_index is not None:
  1079. vocab[override_index] = ("\t", 0.0)
  1080. return vocab
  1081. def pre_tokenizer(self, replacement, add_prefix_space):
  1082. return pre_tokenizers.Split(" ", "merged_with_previous")
  1083. def unk_id(self, proto):
  1084. unk_id = 3
  1085. return unk_id
  1086. def decoder(self, replacement, add_prefix_space):
  1087. return decoders.Sequence(
  1088. [
  1089. decoders.Replace("▁", " "),
  1090. decoders.ByteFallback(),
  1091. decoders.Fuse(),
  1092. ]
  1093. )
  1094. class LlamaConverter(SpmConverter):
  1095. handle_byte_fallback = True
  1096. def vocab(self, proto):
  1097. vocab = [
  1098. (self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
  1099. (self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
  1100. (self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
  1101. ]
  1102. vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
  1103. return vocab
  1104. def unk_id(self, proto):
  1105. unk_id = 0
  1106. return unk_id
  1107. def decoder(self, replacement, add_prefix_space):
  1108. sequence = [
  1109. decoders.Replace("▁", " "),
  1110. decoders.ByteFallback(),
  1111. decoders.Fuse(),
  1112. ]
  1113. if add_prefix_space:
  1114. sequence += [decoders.Strip(content=" ", left=1)]
  1115. return decoders.Sequence(sequence)
  1116. def normalizer(self, proto):
  1117. if getattr(self.original_tokenizer, "legacy", True):
  1118. sequence = []
  1119. if getattr(self.original_tokenizer, "add_prefix_space", True):
  1120. sequence += [normalizers.Prepend(prepend="▁")]
  1121. sequence += [normalizers.Replace(pattern=" ", content="▁")]
  1122. return normalizers.Sequence(sequence)
  1123. return None # non-legacy, no normalizer
  1124. def pre_tokenizer(self, replacement, add_prefix_space):
  1125. if not getattr(self.original_tokenizer, "legacy", True): # non-legacy, we need a replace
  1126. prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
  1127. return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
  1128. return None
  1129. def post_processor(self):
  1130. # the processor is defined in the LlamaTokenizerFast class.
  1131. return None
  1132. class MarkupLMConverter(Converter):
  1133. def converted(self) -> Tokenizer:
  1134. ot = self.original_tokenizer
  1135. vocab = ot.encoder
  1136. merges = list(ot.bpe_ranks.keys())
  1137. tokenizer = Tokenizer(
  1138. BPE(
  1139. vocab=vocab,
  1140. merges=merges,
  1141. dropout=None,
  1142. continuing_subword_prefix="",
  1143. end_of_word_suffix="",
  1144. fuse_unk=False,
  1145. unk_token=self.original_tokenizer.unk_token,
  1146. )
  1147. )
  1148. tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
  1149. tokenizer.decoder = decoders.ByteLevel()
  1150. cls = str(self.original_tokenizer.cls_token)
  1151. sep = str(self.original_tokenizer.sep_token)
  1152. cls_token_id = self.original_tokenizer.cls_token_id
  1153. sep_token_id = self.original_tokenizer.sep_token_id
  1154. tokenizer.post_processor = processors.TemplateProcessing(
  1155. single=f"{cls} $A {sep}",
  1156. pair=f"{cls} $A {sep} $B {sep}",
  1157. special_tokens=[
  1158. (cls, cls_token_id),
  1159. (sep, sep_token_id),
  1160. ],
  1161. )
  1162. return tokenizer
  1163. class MoshiConverter(SpmConverter):
  1164. handle_byte_fallback = True
  1165. def __init__(self, vocab_file, model_max_length=None, **kwargs):
  1166. requires_backends(self, "protobuf")
  1167. Converter.__init__(self, vocab_file)
  1168. # from .utils import sentencepiece_model_pb2 as model_pb2
  1169. model_pb2 = import_protobuf()
  1170. m = model_pb2.ModelProto()
  1171. with open(vocab_file, "rb") as f:
  1172. m.ParseFromString(f.read())
  1173. self.proto = m
  1174. def normalizer(self, proto):
  1175. precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
  1176. _normalizers = [
  1177. normalizers.Replace(" ", "▁"),
  1178. ]
  1179. if not precompiled_charsmap:
  1180. return normalizers.Sequence(_normalizers)
  1181. else:
  1182. return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)
  1183. def decoder(self, replacement, add_prefix_space):
  1184. sequence = [
  1185. decoders.Replace("▁", " "),
  1186. decoders.ByteFallback(),
  1187. decoders.Fuse(),
  1188. ]
  1189. if add_prefix_space:
  1190. sequence += [decoders.Strip(content=" ", left=1)]
  1191. return decoders.Sequence(sequence)
  1192. def pre_tokenizer(self, replacement, add_prefix_space):
  1193. prepend_scheme = "first"
  1194. return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
  1195. class HeliumConverter(SpmConverter):
  1196. handle_byte_fallback = True
  1197. def __init__(self, vocab_file=None, **kwargs):
  1198. requires_backends(self, "protobuf")
  1199. Converter.__init__(self, vocab_file)
  1200. model_pb2 = import_protobuf()
  1201. m = model_pb2.ModelProto()
  1202. with open(vocab_file, "rb") as f:
  1203. m.ParseFromString(f.read())
  1204. self.proto = m
  1205. def tokenizer(self, proto):
  1206. vocab_scores = self.vocab(proto)
  1207. tokenizer = Tokenizer(
  1208. Unigram(
  1209. vocab_scores,
  1210. unk_id=self.unk_id(proto),
  1211. byte_fallback=self.handle_byte_fallback,
  1212. )
  1213. )
  1214. # control tokens are special
  1215. # user defined symbols are not
  1216. # both user and control tokens are AddedTokens
  1217. # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
  1218. spm_added_tokens = [
  1219. (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
  1220. for id, p in enumerate(proto.pieces)
  1221. if p.type in [3, 4]
  1222. ]
  1223. tokenizer.add_tokens(
  1224. [
  1225. AddedToken(token, normalized=False, special=special, single_word=True)
  1226. for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
  1227. ]
  1228. )
  1229. tokenizer.add_tokens([AddedToken("\n", normalized=False, special=False)])
  1230. tokenizer.enable_padding(pad_token="<pad>", pad_id=3)
  1231. return tokenizer
  1232. def vocab(self, proto):
  1233. vocab = []
  1234. for piece in proto.pieces:
  1235. if piece.piece == "<0x0A>":
  1236. vocab += [("\n", piece.score)]
  1237. else:
  1238. vocab += [(piece.piece, piece.score)]
  1239. return vocab
  1240. def unk_id(self, proto):
  1241. unk_id = 0
  1242. return unk_id
  1243. def decoder(self, replacement, add_prefix_space):
  1244. sequence = [
  1245. decoders.Replace("▁", " "),
  1246. decoders.ByteFallback(),
  1247. decoders.Fuse(),
  1248. ]
  1249. sequence += [decoders.Strip(content=" ", left=1)]
  1250. return decoders.Sequence(sequence)
  1251. def normalizer(self, proto):
  1252. return normalizers.Sequence([normalizers.Prepend(" "), normalizers.Replace(r" ", "▁")])
  1253. def pre_tokenizer(self, replacement, add_prefix_space):
  1254. return pre_tokenizers.Sequence([pre_tokenizers.Split("\n", "contiguous")])
  1255. def post_processor(self):
  1256. return processors.TemplateProcessing(
  1257. single=[
  1258. "<s>",
  1259. "$A",
  1260. ],
  1261. pair=[
  1262. "<s>",
  1263. "$A",
  1264. "<s>",
  1265. "$B",
  1266. ],
  1267. special_tokens=[
  1268. ("<s>", 1),
  1269. ],
  1270. )
  1271. class ParakeetConverter(SpmConverter):
  1272. handle_byte_fallback = True
  1273. def __init__(self, vocab_file=None, *args):
  1274. self.vocab_file = vocab_file
  1275. requires_backends(self, "protobuf")
  1276. Converter.__init__(self, vocab_file)
  1277. model_pb2 = import_protobuf()
  1278. m = model_pb2.ModelProto()
  1279. with open(vocab_file, "rb") as f:
  1280. m.ParseFromString(f.read())
  1281. self.proto = m
  1282. def tokenizer(self, proto):
  1283. vocab_scores = self.vocab(proto)
  1284. _, merges = self.SpmExtractor(self.vocab_file).extract(vocab_scores)
  1285. bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
  1286. tokenizer = Tokenizer(
  1287. BPE(
  1288. bpe_vocab,
  1289. merges,
  1290. unk_token=proto.trainer_spec.unk_piece,
  1291. fuse_unk=True,
  1292. byte_fallback=self.handle_byte_fallback,
  1293. dropout=None,
  1294. )
  1295. )
  1296. # Add user defined symbols and control tokens from sentencepiece model
  1297. spm_added_tokens = [
  1298. (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
  1299. for id, p in enumerate(proto.pieces)
  1300. if p.type in [3, 4]
  1301. ]
  1302. tokenizer.add_tokens(
  1303. [
  1304. AddedToken(token, normalized=False, special=special)
  1305. for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
  1306. ]
  1307. )
  1308. return tokenizer
  1309. # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
  1310. def bytes_to_unicode():
  1311. """
  1312. Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
  1313. characters the bpe code barfs on.
  1314. The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
  1315. if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
  1316. decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
  1317. tables between utf-8 bytes and unicode strings.
  1318. """
  1319. bs = (
  1320. list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
  1321. )
  1322. cs = bs[:]
  1323. n = 0
  1324. for b in range(2**8):
  1325. if b not in bs:
  1326. bs.append(b)
  1327. cs.append(2**8 + n)
  1328. n += 1
  1329. cs = [chr(n) for n in cs]
  1330. return dict(zip(bs, cs))
  1331. class TikTokenConverter:
  1332. """
  1333. A general tiktoken converter.
  1334. """
  1335. def __init__(
  1336. self,
  1337. vocab_file=None,
  1338. pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
  1339. add_prefix_space=False,
  1340. additional_special_tokens=None,
  1341. **kwargs,
  1342. ):
  1343. self.vocab_file = vocab_file
  1344. self.pattern = pattern
  1345. self.add_prefix_space = add_prefix_space
  1346. self.additional_special_tokens = (
  1347. additional_special_tokens.keys()
  1348. if isinstance(additional_special_tokens, dict)
  1349. else additional_special_tokens
  1350. )
  1351. def extract_vocab_merges_from_model(self, tiktoken_url: str):
  1352. try:
  1353. from tiktoken.load import load_tiktoken_bpe
  1354. except Exception:
  1355. raise ValueError(
  1356. "`tiktoken` is required to read a `tiktoken` file. Install it with `pip install tiktoken`."
  1357. )
  1358. bpe_ranks = load_tiktoken_bpe(tiktoken_url)
  1359. byte_encoder = bytes_to_unicode()
  1360. def token_bytes_to_string(b):
  1361. return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
  1362. merges = []
  1363. vocab = {}
  1364. for token, rank in bpe_ranks.items():
  1365. vocab[token_bytes_to_string(token)] = rank
  1366. if len(token) == 1:
  1367. continue
  1368. local = []
  1369. for index in range(1, len(token)):
  1370. piece_l, piece_r = token[:index], token[index:]
  1371. if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
  1372. local.append((piece_l, piece_r, rank))
  1373. local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
  1374. merges.extend(local)
  1375. merges = sorted(merges, key=lambda val: val[2], reverse=False)
  1376. merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
  1377. return vocab, merges
  1378. def tokenizer(self):
  1379. vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
  1380. tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
  1381. if hasattr(tokenizer.model, "ignore_merges"):
  1382. tokenizer.model.ignore_merges = True
  1383. return tokenizer
  1384. def converted(self) -> Tokenizer:
  1385. tokenizer = self.tokenizer()
  1386. tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
  1387. [
  1388. pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
  1389. pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
  1390. ]
  1391. )
  1392. tokenizer.decoder = decoders.ByteLevel()
  1393. tokenizer.add_special_tokens(
  1394. [AddedToken(token, normalized=False, special=True) for token in self.additional_special_tokens]
  1395. )
  1396. tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
  1397. return tokenizer
  1398. class MistralConverter:
  1399. def __init__(
  1400. self,
  1401. vocab_file=None,
  1402. pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
  1403. add_prefix_space=False,
  1404. additional_special_tokens=None,
  1405. **kwargs,
  1406. ):
  1407. self.vocab_file = vocab_file
  1408. self.pattern = pattern
  1409. self.add_prefix_space = add_prefix_space
  1410. self.additional_special_tokens = (
  1411. additional_special_tokens.keys()
  1412. if isinstance(additional_special_tokens, dict)
  1413. else additional_special_tokens
  1414. )
  1415. def extract_vocab_merges_from_model(self, tiktoken_url: str):
  1416. import base64
  1417. import json
  1418. with open(self.vocab_file, "r", encoding="utf-8") as f:
  1419. untyped = json.load(f)
  1420. self.pattern = untyped["config"]["pattern"]
  1421. self.additional_special_tokens = [
  1422. AddedToken(k["token_str"], special=k["is_control"]) for k in untyped["special_tokens"]
  1423. ]
  1424. bpe_ranks = untyped["vocab"]
  1425. byte_encoder = bytes_to_unicode()
  1426. @lru_cache
  1427. def token_bytes_to_string(b):
  1428. return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
  1429. merges = []
  1430. vocab = {}
  1431. for idx, token in enumerate(self.additional_special_tokens):
  1432. vocab[token.content] = idx
  1433. bpe_ranks = [base64.b64decode(k["token_bytes"]) for k in bpe_ranks]
  1434. rank_set = set(bpe_ranks)
  1435. for rank, token in enumerate(tqdm(bpe_ranks, desc="Converting tekken.json to tokenizer.json")):
  1436. vocab[token_bytes_to_string(token)] = rank
  1437. if len(token) == 1:
  1438. continue
  1439. local = []
  1440. for index in range(1, len(token)):
  1441. piece_l, piece_r = token[:index], token[index:]
  1442. if piece_l in rank_set and piece_r in rank_set and (piece_l + piece_r) in rank_set:
  1443. local.append((piece_l, piece_r, rank))
  1444. local = sorted(local, key=lambda x: (bpe_ranks.index(x[0]), bpe_ranks.index(x[1])), reverse=False)
  1445. merges.extend(local)
  1446. merges = sorted(merges, key=lambda val: val[2], reverse=False)
  1447. merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
  1448. return vocab, merges
  1449. def tokenizer(self):
  1450. vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
  1451. tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
  1452. if hasattr(tokenizer.model, "ignore_merges"):
  1453. tokenizer.model.ignore_merges = True
  1454. return tokenizer
  1455. def converted(self) -> Tokenizer:
  1456. tokenizer = self.tokenizer()
  1457. tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
  1458. [
  1459. pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
  1460. pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
  1461. ]
  1462. )
  1463. tokenizer.decoder = decoders.ByteLevel()
  1464. tokenizer.add_tokens(self.additional_special_tokens)
  1465. tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
  1466. return tokenizer
  1467. SLOW_TO_FAST_CONVERTERS = {
  1468. "AlbertTokenizer": AlbertConverter,
  1469. "BartTokenizer": RobertaConverter,
  1470. "BarthezTokenizer": BarthezConverter,
  1471. "BertTokenizer": BertConverter,
  1472. "BigBirdTokenizer": BigBirdConverter,
  1473. "BlenderbotTokenizer": BlenderbotConverter,
  1474. "CamembertTokenizer": CamembertConverter,
  1475. "CLIPTokenizer": CLIPConverter,
  1476. "CodeGenTokenizer": GPT2Converter,
  1477. "ConvBertTokenizer": BertConverter,
  1478. "DebertaTokenizer": DebertaConverter,
  1479. "DebertaV2Tokenizer": DebertaV2Converter,
  1480. "DistilBertTokenizer": BertConverter,
  1481. "DPRReaderTokenizer": BertConverter,
  1482. "DPRQuestionEncoderTokenizer": BertConverter,
  1483. "DPRContextEncoderTokenizer": BertConverter,
  1484. "ElectraTokenizer": BertConverter,
  1485. "FNetTokenizer": AlbertConverter,
  1486. "FunnelTokenizer": FunnelConverter,
  1487. "GPT2Tokenizer": GPT2Converter,
  1488. "HerbertTokenizer": HerbertConverter,
  1489. "LayoutLMTokenizer": BertConverter,
  1490. "LayoutLMv2Tokenizer": BertConverter,
  1491. "LayoutLMv3Tokenizer": RobertaConverter,
  1492. "LayoutXLMTokenizer": XLMRobertaConverter,
  1493. "LongformerTokenizer": RobertaConverter,
  1494. "LEDTokenizer": RobertaConverter,
  1495. "LxmertTokenizer": BertConverter,
  1496. "MarkupLMTokenizer": MarkupLMConverter,
  1497. "MBartTokenizer": MBartConverter,
  1498. "MBart50Tokenizer": MBart50Converter,
  1499. "MPNetTokenizer": MPNetConverter,
  1500. "MobileBertTokenizer": BertConverter,
  1501. "MvpTokenizer": RobertaConverter,
  1502. "NllbTokenizer": NllbConverter,
  1503. "OpenAIGPTTokenizer": OpenAIGPTConverter,
  1504. "PegasusTokenizer": PegasusConverter,
  1505. "Qwen2Tokenizer": Qwen2Converter,
  1506. "RealmTokenizer": BertConverter,
  1507. "ReformerTokenizer": ReformerConverter,
  1508. "RemBertTokenizer": RemBertConverter,
  1509. "RetriBertTokenizer": BertConverter,
  1510. "RobertaTokenizer": RobertaConverter,
  1511. "RoFormerTokenizer": RoFormerConverter,
  1512. "SeamlessM4TTokenizer": SeamlessM4TConverter,
  1513. "SqueezeBertTokenizer": BertConverter,
  1514. "T5Tokenizer": T5Converter,
  1515. "UdopTokenizer": UdopConverter,
  1516. "WhisperTokenizer": WhisperConverter,
  1517. "XLMRobertaTokenizer": XLMRobertaConverter,
  1518. "XLNetTokenizer": XLNetConverter,
  1519. "SplinterTokenizer": SplinterConverter,
  1520. "XGLMTokenizer": XGLMConverter,
  1521. "LlamaTokenizer": LlamaConverter,
  1522. "CodeLlamaTokenizer": LlamaConverter,
  1523. "GemmaTokenizer": GemmaConverter,
  1524. "Phi3Tokenizer": LlamaConverter,
  1525. }
  1526. def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
  1527. """
  1528. Utilities to convert a slow tokenizer instance in a fast tokenizer instance.
  1529. Args:
  1530. transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
  1531. Instance of a slow tokenizer to convert in the backend tokenizer for
  1532. [`~tokenization_utils_base.PreTrainedTokenizerFast`].
  1533. from_tiktoken (bool, optional): Whether to use the `tiktoken` library to convert the tokenizer instead of sentencepiece.
  1534. Defaults to False.
  1535. Return:
  1536. A instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
  1537. [`~tokenization_utils_base.PreTrainedTokenizerFast`]
  1538. """
  1539. tokenizer_class_name = transformer_tokenizer.__class__.__name__
  1540. if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
  1541. converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
  1542. return converter_class(transformer_tokenizer).converted()
  1543. elif transformer_tokenizer.vocab_file.endswith("tekken.json"):
  1544. transformer_tokenizer.original_tokenizer = transformer_tokenizer
  1545. logger.info("Converting from Mistral tekken.json")
  1546. return MistralConverter(transformer_tokenizer.vocab_file).converted()
  1547. else:
  1548. try:
  1549. logger.info("Converting from Tiktoken")
  1550. return TikTokenConverter(
  1551. vocab_file=transformer_tokenizer.vocab_file,
  1552. additional_special_tokens=transformer_tokenizer.additional_special_tokens,
  1553. ).converted()
  1554. except Exception:
  1555. raise ValueError(
  1556. f"Converting from SentencePiece and Tiktoken failed, if a converter for SentencePiece is available, provide a model path "
  1557. f"with a SentencePiece tokenizer.model file."
  1558. f"Currently available slow->fast converters: {list(SLOW_TO_FAST_CONVERTERS.keys())}"
  1559. )