__init__.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230
  1. # This file was automatically generated by SWIG (https://www.swig.org).
  2. # Version 4.3.0
  3. #
  4. # Do not make changes to this file unless you know what you are doing - modify
  5. # the SWIG interface file instead.
  6. from sys import version_info as _swig_python_version_info
  7. # Import the low-level C/C++ module
  8. if __package__ or "." in __name__:
  9. from . import _sentencepiece
  10. else:
  11. import _sentencepiece
  12. try:
  13. import builtins as __builtin__
  14. except ImportError:
  15. import __builtin__
  16. def _swig_repr(self):
  17. try:
  18. strthis = "proxy of " + self.this.__repr__()
  19. except __builtin__.Exception:
  20. strthis = ""
  21. return "<%s.%s; %s >" % (self.__class__.__module__, self.__class__.__name__, strthis,)
  22. def _swig_setattr_nondynamic_instance_variable(set):
  23. def set_instance_attr(self, name, value):
  24. if name == "this":
  25. set(self, name, value)
  26. elif name == "thisown":
  27. self.this.own(value)
  28. elif hasattr(self, name) and isinstance(getattr(type(self), name), property):
  29. set(self, name, value)
  30. else:
  31. raise AttributeError("You cannot add instance attributes to %s" % self)
  32. return set_instance_attr
  33. def _swig_setattr_nondynamic_class_variable(set):
  34. def set_class_attr(cls, name, value):
  35. if hasattr(cls, name) and not isinstance(getattr(cls, name), property):
  36. set(cls, name, value)
  37. else:
  38. raise AttributeError("You cannot add class attributes to %s" % cls)
  39. return set_class_attr
  40. def _swig_add_metaclass(metaclass):
  41. """Class decorator for adding a metaclass to a SWIG wrapped class - a slimmed down version of six.add_metaclass"""
  42. def wrapper(cls):
  43. return metaclass(cls.__name__, cls.__bases__, cls.__dict__.copy())
  44. return wrapper
  45. class _SwigNonDynamicMeta(type):
  46. """Meta class to enforce nondynamic attributes (no new attributes) for a class"""
  47. __setattr__ = _swig_setattr_nondynamic_class_variable(type.__setattr__)
  48. class ImmutableSentencePieceText_ImmutableSentencePiece(object):
  49. thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
  50. __repr__ = _swig_repr
  51. def __init__(self):
  52. _sentencepiece.ImmutableSentencePieceText_ImmutableSentencePiece_swiginit(self, _sentencepiece.new_ImmutableSentencePieceText_ImmutableSentencePiece())
  53. __swig_destroy__ = _sentencepiece.delete_ImmutableSentencePieceText_ImmutableSentencePiece
  54. def _piece(self):
  55. return _sentencepiece.ImmutableSentencePieceText_ImmutableSentencePiece__piece(self)
  56. def _surface(self):
  57. return _sentencepiece.ImmutableSentencePieceText_ImmutableSentencePiece__surface(self)
  58. def _id(self):
  59. return _sentencepiece.ImmutableSentencePieceText_ImmutableSentencePiece__id(self)
  60. def _begin(self):
  61. return _sentencepiece.ImmutableSentencePieceText_ImmutableSentencePiece__begin(self)
  62. def _end(self):
  63. return _sentencepiece.ImmutableSentencePieceText_ImmutableSentencePiece__end(self)
  64. def _surface_as_bytes(self):
  65. return _sentencepiece.ImmutableSentencePieceText_ImmutableSentencePiece__surface_as_bytes(self)
  66. def _piece_as_bytes(self):
  67. return _sentencepiece.ImmutableSentencePieceText_ImmutableSentencePiece__piece_as_bytes(self)
  68. piece = property(_piece)
  69. piece_as_bytes = property(_piece_as_bytes)
  70. surface = property(_surface)
  71. surface_as_bytes = property(_surface_as_bytes)
  72. id = property(_id)
  73. begin = property(_begin)
  74. end = property(_end)
  75. def __str__(self):
  76. return ('piece: \"{}\"\n'
  77. 'id: {}\n'
  78. 'surface: \"{}\"\n'
  79. 'begin: {}\n'
  80. 'end: {}\n').format(self.piece, self.id, self.surface,
  81. self.begin, self.end)
  82. def __eq__(self, other):
  83. return self.piece == other.piece and self.id == other.id and self.surface == other.surface and self.begin == other.begin and self.end == other.end
  84. def __hash__(self):
  85. return hash(str(self))
  86. __repr__ = __str__
  87. # Register ImmutableSentencePieceText_ImmutableSentencePiece in _sentencepiece:
  88. _sentencepiece.ImmutableSentencePieceText_ImmutableSentencePiece_swigregister(ImmutableSentencePieceText_ImmutableSentencePiece)
  89. class ImmutableSentencePieceText(object):
  90. thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
  91. __repr__ = _swig_repr
  92. def __init__(self):
  93. _sentencepiece.ImmutableSentencePieceText_swiginit(self, _sentencepiece.new_ImmutableSentencePieceText())
  94. __swig_destroy__ = _sentencepiece.delete_ImmutableSentencePieceText
  95. def _pieces_size(self):
  96. return _sentencepiece.ImmutableSentencePieceText__pieces_size(self)
  97. def _pieces(self, index):
  98. return _sentencepiece.ImmutableSentencePieceText__pieces(self, index)
  99. def _text(self):
  100. return _sentencepiece.ImmutableSentencePieceText__text(self)
  101. def _score(self):
  102. return _sentencepiece.ImmutableSentencePieceText__score(self)
  103. def SerializeAsString(self):
  104. return _sentencepiece.ImmutableSentencePieceText_SerializeAsString(self)
  105. def _text_as_bytes(self):
  106. return _sentencepiece.ImmutableSentencePieceText__text_as_bytes(self)
  107. text = property(_text)
  108. text_as_bytes = property(_text_as_bytes)
  109. score = property(_score)
  110. class ImmutableSentencePieceIterator:
  111. def __init__(self, proto):
  112. self.proto = proto
  113. self.len = self.proto._pieces_size()
  114. def __len__(self):
  115. return self.len
  116. def __getitem__(self, index):
  117. if isinstance(index, slice):
  118. return [self.proto._pieces(i) for i in range(self.len)][index.start:index.stop:index.step]
  119. if index < 0:
  120. index = index + self.len
  121. if index < 0 or index >= self.len:
  122. raise IndexError('piece index is out of range')
  123. return self.proto._pieces(index)
  124. def __str__(self):
  125. return '\n'.join(['pieces {{\n{}}}'.format(str(x)) for x in self])
  126. __repr__ = __str__
  127. @property
  128. def pieces(self):
  129. return ImmutableSentencePieceText.ImmutableSentencePieceIterator(self)
  130. def __eq__(self, other):
  131. return self.SerializeAsString() == other.SerializeAsString()
  132. def __hash__(self):
  133. return hash(self.SerializeAsString())
  134. def __str__(self):
  135. return ('text: \"{}\"\n'
  136. 'score: {}\n'
  137. '{}').format(self.text, self.score,
  138. '\n'.join(['pieces {{\n{}}}'.format(str(x)) for x in self.pieces]))
  139. __repr__ = __str__
  140. # Register ImmutableSentencePieceText in _sentencepiece:
  141. _sentencepiece.ImmutableSentencePieceText_swigregister(ImmutableSentencePieceText)
  142. class ImmutableNBestSentencePieceText(object):
  143. thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
  144. __repr__ = _swig_repr
  145. def __init__(self):
  146. _sentencepiece.ImmutableNBestSentencePieceText_swiginit(self, _sentencepiece.new_ImmutableNBestSentencePieceText())
  147. __swig_destroy__ = _sentencepiece.delete_ImmutableNBestSentencePieceText
  148. def _nbests_size(self):
  149. return _sentencepiece.ImmutableNBestSentencePieceText__nbests_size(self)
  150. def _nbests(self, index):
  151. return _sentencepiece.ImmutableNBestSentencePieceText__nbests(self, index)
  152. def SerializeAsString(self):
  153. return _sentencepiece.ImmutableNBestSentencePieceText_SerializeAsString(self)
  154. class ImmutableSentencePieceTextIterator:
  155. def __init__(self, proto):
  156. self.proto = proto
  157. self.len = self.proto._nbests_size()
  158. def __len__(self):
  159. return self.len
  160. def __getitem__(self, index):
  161. if isinstance(index, slice):
  162. return [self.proto._nbests(i) for i in range(self.len)][index.start:index.stop:index.step]
  163. if index < 0:
  164. index = index + self.len
  165. if index < 0 or index >= self.len:
  166. raise IndexError('nbests index is out of range')
  167. return self.proto._nbests(index)
  168. def __str__(self):
  169. return '\n'.join(['nbests {{\n{}}}'.format(str(x)) for x in self])
  170. __repr__ = __str__
  171. @property
  172. def nbests(self):
  173. return ImmutableNBestSentencePieceText.ImmutableSentencePieceTextIterator(self)
  174. def __eq__(self, other):
  175. return self.SerializeAsString() == other.SerializeAsString()
  176. def __hash__(self):
  177. return hash(self.SerializeAsString())
  178. def __str__(self):
  179. return '\n'.join(['nbests {{\n{}}}'.format(str(x)) for x in self.nbests])
  180. __repr__ = __str__
  181. # Register ImmutableNBestSentencePieceText in _sentencepiece:
  182. _sentencepiece.ImmutableNBestSentencePieceText_swigregister(ImmutableNBestSentencePieceText)
  183. class SentencePieceProcessor(object):
  184. thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
  185. __repr__ = _swig_repr
  186. def __init__(self):
  187. _sentencepiece.SentencePieceProcessor_swiginit(self, _sentencepiece.new_SentencePieceProcessor())
  188. __swig_destroy__ = _sentencepiece.delete_SentencePieceProcessor
    # Public SWIG-generated bindings. Each method in this group forwards
    # `self` and its arguments, unchanged and in order, to the native function
    # `_sentencepiece.SentencePieceProcessor_<MethodName>` and returns that
    # function's result. No Python-side validation or conversion is done here.
    def LoadFromSerializedProto(self, serialized):
        """Forward `serialized` to the native LoadFromSerializedProto (presumably serialized model bytes)."""
        return _sentencepiece.SentencePieceProcessor_LoadFromSerializedProto(self, serialized)
    def SetEncodeExtraOptions(self, extra_option):
        """Forward `extra_option` to the native SetEncodeExtraOptions."""
        return _sentencepiece.SentencePieceProcessor_SetEncodeExtraOptions(self, extra_option)
    def SetDecodeExtraOptions(self, extra_option):
        """Forward `extra_option` to the native SetDecodeExtraOptions."""
        return _sentencepiece.SentencePieceProcessor_SetDecodeExtraOptions(self, extra_option)
    def SetVocabulary(self, valid_vocab):
        """Forward `valid_vocab` to the native SetVocabulary."""
        return _sentencepiece.SentencePieceProcessor_SetVocabulary(self, valid_vocab)
    def ResetVocabulary(self):
        """Call the native ResetVocabulary (presumably undoes SetVocabulary — TODO confirm)."""
        return _sentencepiece.SentencePieceProcessor_ResetVocabulary(self)
    def LoadVocabulary(self, filename, threshold):
        """Forward `filename` and `threshold` to the native LoadVocabulary."""
        return _sentencepiece.SentencePieceProcessor_LoadVocabulary(self, filename, threshold)
    def CalculateEntropy(self, *args):
        """Forward all positional arguments to the native CalculateEntropy."""
        # *args is passed through untouched, presumably so SWIG can dispatch
        # among overloaded native signatures.
        return _sentencepiece.SentencePieceProcessor_CalculateEntropy(self, *args)
    def GetPieceSize(self):
        """Return the native GetPieceSize result (presumably the vocabulary size)."""
        return _sentencepiece.SentencePieceProcessor_GetPieceSize(self)
    def PieceToId(self, piece):
        """Return the native PieceToId result for `piece`."""
        return _sentencepiece.SentencePieceProcessor_PieceToId(self, piece)
    def IdToPiece(self, id):
        """Return the native IdToPiece result for `id`."""
        return _sentencepiece.SentencePieceProcessor_IdToPiece(self, id)
    def GetScore(self, id):
        """Return the native GetScore result for `id`."""
        return _sentencepiece.SentencePieceProcessor_GetScore(self, id)
    def IsUnknown(self, id):
        """Return the native IsUnknown flag for `id`."""
        return _sentencepiece.SentencePieceProcessor_IsUnknown(self, id)
    def IsControl(self, id):
        """Return the native IsControl flag for `id`."""
        return _sentencepiece.SentencePieceProcessor_IsControl(self, id)
    def IsUnused(self, id):
        """Return the native IsUnused flag for `id`."""
        return _sentencepiece.SentencePieceProcessor_IsUnused(self, id)
    def IsByte(self, id):
        """Return the native IsByte flag for `id`."""
        return _sentencepiece.SentencePieceProcessor_IsByte(self, id)
    def unk_id(self):
        """Return the native unk_id result."""
        return _sentencepiece.SentencePieceProcessor_unk_id(self)
    def bos_id(self):
        """Return the native bos_id result."""
        return _sentencepiece.SentencePieceProcessor_bos_id(self)
    def eos_id(self):
        """Return the native eos_id result."""
        return _sentencepiece.SentencePieceProcessor_eos_id(self)
    def pad_id(self):
        """Return the native pad_id result."""
        return _sentencepiece.SentencePieceProcessor_pad_id(self)
    def serialized_model_proto(self):
        """Return the native serialized_model_proto result."""
        return _sentencepiece.SentencePieceProcessor_serialized_model_proto(self)
    def LoadFromFile(self, arg):
        """Forward `arg` to the native LoadFromFile (presumably a model file path)."""
        return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)
    # Private encode bindings called by Encode(). The native functions take
    # every option positionally, in exactly this order:
    #   text | ins, [num_threads (batch variants only)], enable_sampling,
    #   nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece
    # The suffix (Ids, Pieces, SerializedProto, ImmutableProto) selects the
    # native result form. Arguments are forwarded unchanged, and the native
    # result is returned as-is.
    def _EncodeAsIds(self, text, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__EncodeAsIds(self, text, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece)
    def _EncodeAsPieces(self, text, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__EncodeAsPieces(self, text, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece)
    def _EncodeAsSerializedProto(self, text, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__EncodeAsSerializedProto(self, text, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece)
    def _EncodeAsImmutableProto(self, text, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__EncodeAsImmutableProto(self, text, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece)
    # Batch variants: `ins` is the list of inputs; `num_threads` comes second.
    def _EncodeAsIdsBatch(self, ins, num_threads, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__EncodeAsIdsBatch(self, ins, num_threads, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece)
    def _EncodeAsPiecesBatch(self, ins, num_threads, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__EncodeAsPiecesBatch(self, ins, num_threads, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece)
    def _EncodeAsSerializedProtoBatch(self, ins, num_threads, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__EncodeAsSerializedProtoBatch(self, ins, num_threads, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece)
    def _EncodeAsImmutableProtoBatch(self, ins, num_threads, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__EncodeAsImmutableProtoBatch(self, ins, num_threads, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece)
    # Private decode bindings. The single-sequence variants take `ids` or
    # `pieces`; the *Batch variants take `ins` plus `num_threads`. The name
    # suffix (AsBytes, AsSerializedProto, AsImmutableProto) presumably selects
    # the native output form. Every method forwards its arguments unchanged
    # to the same-named `_sentencepiece.SentencePieceProcessor_<name>` function
    # and returns its result.
    def _DecodeIds(self, ids):
        return _sentencepiece.SentencePieceProcessor__DecodeIds(self, ids)
    def _DecodeIdsAsBytes(self, ids):
        return _sentencepiece.SentencePieceProcessor__DecodeIdsAsBytes(self, ids)
    def _DecodePieces(self, pieces):
        return _sentencepiece.SentencePieceProcessor__DecodePieces(self, pieces)
    def _DecodeIdsAsSerializedProto(self, ids):
        return _sentencepiece.SentencePieceProcessor__DecodeIdsAsSerializedProto(self, ids)
    def _DecodePiecesAsSerializedProto(self, pieces):
        return _sentencepiece.SentencePieceProcessor__DecodePiecesAsSerializedProto(self, pieces)
    def _DecodeIdsAsImmutableProto(self, ids):
        return _sentencepiece.SentencePieceProcessor__DecodeIdsAsImmutableProto(self, ids)
    def _DecodePiecesAsImmutableProto(self, pieces):
        return _sentencepiece.SentencePieceProcessor__DecodePiecesAsImmutableProto(self, pieces)
    # Batch variants.
    def _DecodeIdsBatch(self, ins, num_threads):
        return _sentencepiece.SentencePieceProcessor__DecodeIdsBatch(self, ins, num_threads)
    def _DecodeIdsAsBytesBatch(self, ins, num_threads):
        return _sentencepiece.SentencePieceProcessor__DecodeIdsAsBytesBatch(self, ins, num_threads)
    def _DecodeIdsAsSerializedProtoBatch(self, ins, num_threads):
        return _sentencepiece.SentencePieceProcessor__DecodeIdsAsSerializedProtoBatch(self, ins, num_threads)
    def _DecodeIdsAsImmutableProtoBatch(self, ins, num_threads):
        return _sentencepiece.SentencePieceProcessor__DecodeIdsAsImmutableProtoBatch(self, ins, num_threads)
    def _DecodePiecesBatch(self, ins, num_threads):
        return _sentencepiece.SentencePieceProcessor__DecodePiecesBatch(self, ins, num_threads)
    def _DecodePiecesAsSerializedProtoBatch(self, ins, num_threads):
        return _sentencepiece.SentencePieceProcessor__DecodePiecesAsSerializedProtoBatch(self, ins, num_threads)
    def _DecodePiecesAsImmutableProtoBatch(self, ins, num_threads):
        return _sentencepiece.SentencePieceProcessor__DecodePiecesAsImmutableProtoBatch(self, ins, num_threads)
    # Private n-best encoders (called by NBestEncode()) and sample-and-score
    # encoders (presumably called by SampleEncodeAndScore()). All arguments
    # are positional and are forwarded unchanged, in the order shown, to the
    # same-named native function. The native result is returned as-is.
    def _NBestEncodeAsIds(self, text, nbest_size, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__NBestEncodeAsIds(self, text, nbest_size, add_bos, add_eos, reverse, emit_unk_piece)
    def _NBestEncodeAsPieces(self, text, nbest_size, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__NBestEncodeAsPieces(self, text, nbest_size, add_bos, add_eos, reverse, emit_unk_piece)
    def _NBestEncodeAsSerializedProto(self, text, nbest_size, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__NBestEncodeAsSerializedProto(self, text, nbest_size, add_bos, add_eos, reverse, emit_unk_piece)
    def _NBestEncodeAsImmutableProto(self, text, nbest_size, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__NBestEncodeAsImmutableProto(self, text, nbest_size, add_bos, add_eos, reverse, emit_unk_piece)
    # Sample-and-score variants. `wor` and `include_best` are passed through
    # untouched.
    def _SampleEncodeAndScoreAsIds(self, text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__SampleEncodeAndScoreAsIds(self, text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece)
    def _SampleEncodeAndScoreAsPieces(self, text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__SampleEncodeAndScoreAsPieces(self, text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece)
    def _SampleEncodeAndScoreAsSerializedProto(self, text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__SampleEncodeAndScoreAsSerializedProto(self, text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece)
    def _SampleEncodeAndScoreAsImmutableProto(self, text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__SampleEncodeAndScoreAsImmutableProto(self, text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece)
    # Private bindings for normalization, entropy calculation and
    # normalizer-spec overrides; their callers are outside this view. Each one
    # forwards its arguments unchanged to the same-named
    # `_sentencepiece.SentencePieceProcessor_<name>` function.
    def _Normalize(self, text):
        return _sentencepiece.SentencePieceProcessor__Normalize(self, text)
    def _NormalizeWithOffsets(self, text):
        return _sentencepiece.SentencePieceProcessor__NormalizeWithOffsets(self, text)
    def _CalculateEntropy(self, text, alpha):
        return _sentencepiece.SentencePieceProcessor__CalculateEntropy(self, text, alpha)
    def _CalculateEntropyBatch(self, ins, alpha, num_threads):
        # Note the argument order: alpha precedes num_threads here.
        return _sentencepiece.SentencePieceProcessor__CalculateEntropyBatch(self, ins, alpha, num_threads)
    def _OverrideNormalizerSpec(self, args):
        return _sentencepiece.SentencePieceProcessor__OverrideNormalizerSpec(self, args)
  301. def Init(self,
  302. model_file=None,
  303. model_proto=None,
  304. out_type=int,
  305. add_bos=False,
  306. add_eos=False,
  307. reverse=False,
  308. emit_unk_piece=False,
  309. enable_sampling=False,
  310. nbest_size=-1,
  311. alpha=0.1,
  312. num_threads=-1):
  313. """Initialzie sentencepieceProcessor.
  314. Args:
  315. model_file: The sentencepiece model file path.
  316. model_proto: The sentencepiece model serialized proto.
  317. out_type: output type. int or str.
  318. add_bos: Add <s> to the result (Default = false)
  319. add_eos: Add </s> to the result (Default = false) <s>/</s> is added after
  320. reversing (if enabled).
  321. reverse: Reverses the tokenized sequence (Default = false)
  322. emit_unk_piece: Emits the unk literal string (Default = false)
  323. nbest_size: sampling parameters for unigram. Invalid in BPE-Dropout.
  324. nbest_size = {0,1}: No sampling is performed.
  325. nbest_size > 1: samples from the nbest_size results.
  326. nbest_size < 0: assuming that nbest_size is infinite and samples
  327. from the all hypothesis (lattice) using
  328. forward-filtering-and-backward-sampling algorithm.
  329. alpha: Soothing parameter for unigram sampling, and dropout probability of
  330. merge operations for BPE-dropout.
  331. num_threads: number of threads in batch processing (Default = -1, auto-detected)
  332. """
  333. _sentencepiece_processor_init_native(self)
  334. self._out_type = out_type
  335. self._add_bos = add_bos
  336. self._add_eos = add_eos
  337. self._reverse = reverse
  338. self._emit_unk_piece = emit_unk_piece
  339. self._enable_sampling = enable_sampling
  340. self._nbest_size = nbest_size
  341. self._alpha = alpha
  342. self._num_threads = num_threads
  343. if model_file or model_proto:
  344. self.Load(model_file=model_file, model_proto=model_proto)
  345. def Encode(self,
  346. input,
  347. out_type=None,
  348. add_bos=None,
  349. add_eos=None,
  350. reverse=None,
  351. emit_unk_piece=None,
  352. enable_sampling=None,
  353. nbest_size=None,
  354. alpha=None,
  355. num_threads=None):
  356. """Encode text input to segmented ids or tokens.
  357. Args:
  358. input: input string. accepsts list of string.
  359. out_type: output type. int or str.
  360. add_bos: Add <s> to the result (Default = false)
  361. add_eos: Add </s> to the result (Default = false) <s>/</s> is added after
  362. reversing (if enabled).
  363. reverse: Reverses the tokenized sequence (Default = false)
  364. emit_unk_piece: Emits the unk literal string (Default = false)
  365. nbest_size: sampling parameters for unigram. Invalid in BPE-Dropout.
  366. nbest_size = {0,1}: No sampling is performed.
  367. nbest_size > 1: samples from the nbest_size results.
  368. nbest_size < 0: assuming that nbest_size is infinite and samples
  369. from the all hypothesis (lattice) using
  370. forward-filtering-and-backward-sampling algorithm.
  371. alpha: Soothing parameter for unigram sampling, and merge probability for
  372. BPE-dropout (probablity 'p' in BPE-dropout paper).
  373. num_threads: the number of threads used in the batch processing (Default = -1).
  374. """
  375. if out_type is None:
  376. out_type = self._out_type
  377. if add_bos is None:
  378. add_bos = self._add_bos
  379. if add_eos is None:
  380. add_eos = self._add_eos
  381. if reverse is None:
  382. reverse = self._reverse
  383. if emit_unk_piece is None:
  384. emit_unk_piece = self._emit_unk_piece
  385. if enable_sampling is None:
  386. enable_sampling = self._enable_sampling
  387. if nbest_size is None:
  388. nbest_size = self._nbest_size
  389. if alpha is None:
  390. alpha = self._alpha
  391. if num_threads is None:
  392. num_threads = self._num_threads
  393. if enable_sampling == True and (nbest_size is None or nbest_size == 0 or
  394. nbest_size == 1 or alpha is None):
  395. raise RuntimeError(
  396. 'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", '
  397. 'and "alpha". "nbest_size" is enabled only on unigram mode ignored in BPE-dropout. '
  398. 'when "nbest_size = -1" , this method samples from all candidates on the lattice '
  399. 'instead of nbest segmentations.'
  400. )
  401. if num_threads is None or type(num_threads) is not int:
  402. raise RuntimeError('num_threads must be int')
  403. if type(input) is list:
  404. if out_type is int:
  405. return self._EncodeAsIdsBatch(input, num_threads, enable_sampling, nbest_size,
  406. alpha, add_bos, add_eos, reverse, emit_unk_piece)
  407. if out_type is str:
  408. return self._EncodeAsPiecesBatch(input, num_threads, enable_sampling, nbest_size,
  409. alpha, add_bos, add_eos, reverse, emit_unk_piece)
  410. if out_type == 'serialized_proto' or out_type == 'proto':
  411. return self._EncodeAsSerializedProtoBatch(input, num_threads, enable_sampling, nbest_size,
  412. alpha, add_bos, add_eos, reverse, emit_unk_piece)
  413. if out_type == 'immutable_proto':
  414. return self._EncodeAsImmutableProtoBatch(input, num_threads, enable_sampling, nbest_size,
  415. alpha, add_bos, add_eos, reverse, emit_unk_piece)
  416. if out_type is int:
  417. return self._EncodeAsIds(input, enable_sampling, nbest_size,
  418. alpha, add_bos, add_eos, reverse, emit_unk_piece)
  419. if out_type is str:
  420. return self._EncodeAsPieces(input, enable_sampling, nbest_size,
  421. alpha, add_bos, add_eos, reverse, emit_unk_piece)
  422. if out_type == 'serialized_proto' or out_type == 'proto':
  423. return self._EncodeAsSerializedProto(input, enable_sampling, nbest_size,
  424. alpha, add_bos, add_eos, reverse, emit_unk_piece)
  425. if out_type == 'immutable_proto':
  426. return self._EncodeAsImmutableProto(input, enable_sampling, nbest_size,
  427. alpha, add_bos, add_eos, reverse, emit_unk_piece)
  428. raise RuntimeError('unknown out_type={}'.format(out_type))
  429. return None
  430. def EncodeAsPieces(self, input, **kwargs):
  431. return self.Encode(input=input, out_type=str, **kwargs)
  432. def EncodeAsIds(self, input, **kwargs):
  433. return self.Encode(input=input, out_type=int, **kwargs)
  434. def EncodeAsSerializedProto(self, input, **kwargs):
  435. return self.Encode(input=input, out_type='serialized_proto', **kwargs)
  436. def EncodeAsImmutableProto(self, input, **kwargs):
  437. return self.Encode(input=input, out_type='immutable_proto', **kwargs)
  438. def SampleEncodeAsPieces(self, input, nbest_size=None, alpha=None, **kwargs):
  439. return self.Encode(input=input, nbest_size=nbest_size, alpha=alpha,
  440. out_type=str, enable_sampling=True, **kwargs)
  441. def SampleEncodeAsIds(self, input, nbest_size=None, alpha=None,**kwargs):
  442. return self.Encode(input=input, nbest_size=nbest_size, alpha=alpha,
  443. out_type=int, enable_sampling=True, **kwargs)
  444. def SampleEncodeAsSerializedProto(self, input, nbest_size=None, alpha=None, **kwargs):
  445. return self.Encode(input=input, nbest_size=nbest_size, alpha=alpha,
  446. out_type='serialized_proto', enable_sampling=True, **kwargs)
  447. def SampleEncodeAsImmutableProto(self, input, nbest_size=None, alpha=None, **kwargs):
  448. return self.Encode(input=input, nbest_size=nbest_size, alpha=alpha,
  449. out_type='immutable_proto', enable_sampling=True, **kwargs)
  450. def NBestEncode(self,
  451. input,
  452. out_type=None,
  453. add_bos=None,
  454. add_eos=None,
  455. reverse=None,
  456. emit_unk_piece=None,
  457. nbest_size=None):
  458. """NBestEncode text input to segmented ids or tokens.
  459. Args:
  460. input: input string. accepsts list of string.
  461. out_type: output type. int or str.
  462. add_bos: Add <s> to the result (Default = false)
  463. add_eos: Add </s> to the result (Default = false) <s>/</s> is added after reversing (if enabled).
  464. reverse: Reverses the tokenized sequence (Default = false)
  465. emit_unk_piece: Emits the unk literal string (Default = false)
  466. nbest_size: nbest size
  467. """
  468. if out_type is None:
  469. out_type = self._out_type
  470. if add_bos is None:
  471. add_bos = self._add_bos
  472. if add_eos is None:
  473. add_eos = self._add_eos
  474. if reverse is None:
  475. reverse = self._reverse
  476. if emit_unk_piece is None:
  477. emit_unk_piece = self._emit_unk_piece
  478. if nbest_size is None:
  479. nbest_size = self._nbest_size
  480. if nbest_size <= 0:
  481. nbest_size=1
  482. def _encode(text):
  483. if out_type is int:
  484. return self._NBestEncodeAsIds(text, nbest_size,
  485. add_bos, add_eos, reverse, emit_unk_piece)
  486. if out_type is str:
  487. return self._NBestEncodeAsPieces(text, nbest_size,
  488. add_bos, add_eos, reverse, emit_unk_piece)
  489. if out_type == 'serialized_proto' or out_type == 'proto':
  490. return self._NBestEncodeAsSerializedProto(text, nbest_size,
  491. add_bos, add_eos, reverse, emit_unk_piece)
  492. if out_type == 'immutable_proto':
  493. return self._NBestEncodeAsImmutableProto(text, nbest_size,
  494. add_bos, add_eos, reverse, emit_unk_piece)
  495. raise RuntimeError('unknown out_type')
  496. if type(input) is list:
  497. return [_encode(n) for n in input]
  498. return _encode(input)
  499. def NBestEncodeAsPieces(self, input, nbest_size=None, **kwargs):
  500. return self.NBestEncode(input=input, nbest_size=nbest_size,
  501. out_type=str, **kwargs)
  502. def NBestEncodeAsIds(self, input, nbest_size=None, **kwargs):
  503. return self.NBestEncode(input=input, nbest_size=nbest_size,
  504. out_type=int, **kwargs)
  505. def NBestEncodeAsSerializedProto(self, input, nbest_size=None, **kwargs):
  506. return self.NBestEncode(input=input, nbest_size=nbest_size,
  507. out_type='serialized_proto', **kwargs)
  508. def NBestEncodeAsImmutableProto(self, input, nbest_size=None, **kwargs):
  509. return self.NBestEncode(input=input, nbest_size=nbest_size,
  510. out_type='immutable_proto', **kwargs)
  511. def SampleEncodeAndScore(self,
  512. input,
  513. out_type=None,
  514. add_bos=None,
  515. add_eos=None,
  516. reverse=None,
  517. emit_unk_piece=None,
  518. num_samples=None,
  519. alpha=None,
  520. wor=None,
  521. include_best=None):
  522. """SampleEncodeAndScore text input to segmented ids or tokens.
  523. Args:
  524. input: input string. accepsts list of string.
  525. out_type: output type. int or str or 'serialized_proto' or 'immutable_proto'
  526. add_bos: Add <s> to the result (Default = false)
  527. add_eos: Add </s> to the result (Default = false) <s>/</s> is added after reversing (if enabled).
  528. reverse: Reverses the tokenized sequence (Default = false)
  529. emit_unk_piece: Emits the unk literal string (Default = false)
  530. num_samples: How many samples to return (Default = 1)
  531. alpha: inverse temperature for sampling
  532. wor: whether to sample without replacement (Default = false)
  533. include_best: whether to include the best tokenization, requires wor=True (Default = false)
  534. """
  535. if out_type is None:
  536. out_type = self._out_type
  537. if add_bos is None:
  538. add_bos = self._add_bos
  539. if add_eos is None:
  540. add_eos = self._add_eos
  541. if reverse is None:
  542. reverse = self._reverse
  543. if emit_unk_piece is None:
  544. emit_unk_piece = self._emit_unk_piece
  545. if num_samples is None:
  546. num_samples = 1
  547. if alpha is None:
  548. alpha = 1.
  549. if wor is None:
  550. wor = False
  551. if include_best is None:
  552. include_best = False
  553. if num_samples <= 0:
  554. raise RuntimeError('num_examples must be positive')
  555. if include_best and not wor:
  556. raise RuntimeError('When include_best is True, We must specify "wor = True".')
  557. def _encode(text):
  558. if out_type is int:
  559. return self._SampleEncodeAndScoreAsIds(text, num_samples, alpha, wor, include_best,
  560. add_bos, add_eos, reverse, emit_unk_piece)
  561. if out_type is str:
  562. return self._SampleEncodeAndScoreAsPieces(text, num_samples, alpha, wor, include_best,
  563. add_bos, add_eos, reverse, emit_unk_piece)
  564. if out_type == 'serialized_proto' or out_type == 'proto':
  565. return self._SampleEncodeAndScoreAsSerializedProto(text, num_samples, alpha, wor, include_best,
  566. add_bos, add_eos, reverse, emit_unk_piece)
  567. if out_type == 'immutable_proto':
  568. return self._SampleEncodeAndScoreAsImmutableProto(text, num_samples, alpha, wor, include_best,
  569. add_bos, add_eos, reverse, emit_unk_piece)
  570. raise RuntimeError('unknown output type')
  571. if type(input) is list:
  572. return [_encode(n) for n in input]
  573. return _encode(input)
  574. def SampleEncodeAndScoreAsPieces(self, input, num_samples=None, alpha=None, **kwargs):
  575. return self.SampleEncodeAndScore(input=input, num_samples=num_samples, alpha=alpha,
  576. out_type=str, **kwargs)
  577. def SampleEncodeAndScoreAsIds(self, input, num_samples=None, alpha=None, **kwargs):
  578. return self.SampleEncodeAndScore(input=input, num_samples=num_samples, alpha=alpha,
  579. out_type=int, **kwargs)
  580. def SampleEncodeAndScoreAsSerializedProto(self, input, num_samples=None, alpha=None, **kwargs):
  581. return self.SampleEncodeAndScore(input=input, num_samples=num_samples, alpha=alpha,
  582. out_type='serialized_proto', **kwargs)
  583. def SampleEncodeAndScoreAsImmutableProto(self, input, num_samples=None, alpha=None, **kwargs):
  584. return self.SampleEncodeAndScore(input=input, num_samples=num_samples, alpha=alpha,
  585. out_type='immutable_proto', **kwargs)
  586. def Decode(self, input, out_type=str, num_threads=None):
  587. """Decode processed id or token sequences.
  588. Args:
  589. out_type: output type. str, bytes or 'serialized_proto' or 'immutable_proto' (Default = str)
  590. num_threads: the number of threads used in the batch processing (Default = -1).
  591. """
  592. if num_threads is None:
  593. num_threads = self._num_threads
  594. if num_threads is None or type(num_threads) is not int:
  595. raise RuntimeError('num_threads must be int')
  596. if not input:
  597. return ''
  598. if out_type is str:
  599. if type(input) is int:
  600. return self._DecodeIds([input])
  601. if type(input) is str:
  602. return self._DecodePieces([input])
  603. if type(input) is list:
  604. if len(input) == 0 or type(input[0]) is int:
  605. return self._DecodeIds(input)
  606. if type(input[0]) is str:
  607. return self._DecodePieces(input)
  608. if type(input[0]) is list:
  609. if len(input[0]) == 0 or type(input[0][0]) is int:
  610. return self._DecodeIdsBatch(input, num_threads)
  611. if type(input[0][0]) is str:
  612. return self._DecodePiecesBatch(input, num_threads)
  613. if out_type is bytes:
  614. if type(input) is int:
  615. return self._DecodeIdsAsBytes([input])
  616. if type(input) is str:
  617. return self._DecodePieces([input])
  618. if type(input) is list:
  619. if len(input) == 0 or type(input[0]) is int:
  620. return self._DecodeIdsAsBytes(input)
  621. if type(input[0]) is str:
  622. return self._DecodePieces(input)
  623. if type(input[0]) is list:
  624. if len(input[0]) == 0 or type(input[0][0]) is int:
  625. return self._DecodeIdsAsBytesBatch(input, num_threads)
  626. if type(input[0][0]) is str:
  627. return self._DecodePiecesBatch(input, num_threads)
  628. if out_type == 'serialized_proto':
  629. if type(input) is int:
  630. return self._DecodeIdsAsSerializedProto([input])
  631. if type(input) is str:
  632. return self._DecodePiecesAsSerializedProto([input])
  633. if type(input) is list:
  634. if len(input) == 0 or type(input[0]) is int:
  635. return self._DecodeIdsAsSerializedProto(input)
  636. if type(input[0]) is str:
  637. return self._DecodePiecesAsSerializedProto(input)
  638. if type(input[0]) is list:
  639. if len(input[0]) == 0 or type(input[0][0]) is int:
  640. return self._DecodeIdsAsSerializedProtoBatch(input, num_threads)
  641. if type(input[0][0]) is str:
  642. return self._DecodePiecesAsSerializedProtoBatch(input, num_threads)
  643. if out_type == 'immutable_proto':
  644. if type(input) is int:
  645. return self._DecodeIdsAsImmutableProto([input])
  646. if type(input) is str:
  647. return self._DecodePiecesAsImmutableProto([input])
  648. if type(input) is list:
  649. if len(input) == 0 or type(input[0]) is int:
  650. return self._DecodeIdsAsImmutableProto(input)
  651. if type(input[0]) is str:
  652. return self._DecodePiecesAsImmutableProto(input)
  653. if type(input[0]) is list:
  654. if len(input[0]) == 0 or type(input[0][0]) is int:
  655. return self._DecodeIdsAsImmutableProtoBatch(input, num_threads)
  656. if type(input[0][0]) is str:
  657. return self._DecodePiecesAsImmutableProtoBatch(input, num_threads)
  658. raise RuntimeError('unknown output or input type')
  659. return None
  660. def DecodePieces(self, input, out_type=str, **kwargs):
  661. return self.Decode(input=input, out_type=out_type, **kwargs)
  662. def DecodeIds(self, input, out_type=str, **kwargs):
  663. return self.Decode(input=input, out_type=out_type, **kwargs)
  664. def DecodePiecesAsSerializedProto(self, input, out_type='serialized_proto', **kwargs):
  665. return self.Decode(input=input, out_type=out_type, **kwargs)
  666. def DecodeIdsAsSerializedProto(self, input, out_type='serialized_proto', **kwargs):
  667. return self.Decode(input=input, out_type=out_type, **kwargs)
  668. def DecodePiecesAsImmutableProto(self, input, out_type='immutable_proto', **kwargs):
  669. return self.Decode(input=input, out_type=out_type, **kwargs)
  670. def DecodeIdsAsImmutableProto(self, input, out_type='immutable_proto', **kwargs):
  671. return self.Decode(input=input, out_type=out_type, **kwargs)
  672. def CalculateEntropy(self, input, alpha, num_threads=None):
  673. """Calculate sentence entropy"""
  674. if type(input) is list:
  675. if num_threads is None:
  676. num_threads = self._num_threads
  677. if num_threads is None or type(num_threads) is not int:
  678. raise RuntimeError('num_threads must be int')
  679. return self._CalculateEntropyBatch(input, alpha, num_threads)
  680. return self._CalculateEntropy(input, alpha)
  681. def Normalize(self, input, with_offsets=None):
  682. def _normalize(text):
  683. if with_offsets:
  684. return self._NormalizeWithOffsets(text)
  685. return self._Normalize(text)
  686. if type(input) is list:
  687. return [_normalize(x) for x in input]
  688. return _normalize(input)
  689. def OverrideNormalizerSpec(self, **kwargs):
  690. new_kwargs = {}
  691. for key, value in kwargs.items():
  692. new_kwargs[key] = str(value)
  693. return self._OverrideNormalizerSpec(new_kwargs)
  694. def piece_size(self):
  695. return self.GetPieceSize()
  696. def vocab_size(self):
  697. return self.GetPieceSize()
  698. def __getstate__(self):
  699. return self.serialized_model_proto()
  700. def __setstate__(self, serialized_model_proto):
  701. self.__init__()
  702. self.LoadFromSerializedProto(serialized_model_proto)
  703. def __len__(self):
  704. return self.GetPieceSize()
  705. def __getitem__(self, piece):
  706. return self.PieceToId(piece)
  707. def Load(self, model_file=None, model_proto=None):
  708. """Overwride SentencePieceProcessor.Load to support both model_file and model_proto.
  709. Args:
  710. model_file: The sentencepiece model file path.
  711. model_proto: The sentencepiece model serialized proto. Either `model_file`
  712. or `model_proto` must be set.
  713. """
  714. if model_file and model_proto:
  715. raise RuntimeError('model_file and model_proto must be exclusive.')
  716. if model_proto:
  717. return self.LoadFromSerializedProto(model_proto)
  718. return self.LoadFromFile(model_file)
# Register SentencePieceProcessor in _sentencepiece:
# SWIG-generated registration step that hands the Python proxy class to the
# extension module; keep exactly as generated.
_sentencepiece.SentencePieceProcessor_swigregister(SentencePieceProcessor)
  721. def SetRandomGeneratorSeed(seed):
  722. return _sentencepiece.SetRandomGeneratorSeed(seed)
  723. def SetMinLogLevel(v):
  724. return _sentencepiece.SetMinLogLevel(v)
  725. class SentencePieceTrainer(object):
  726. thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
  727. def __init__(self, *args, **kwargs):
  728. raise AttributeError("No constructor defined")
  729. __repr__ = _swig_repr
  730. @staticmethod
  731. def _TrainFromString(arg):
  732. return _sentencepiece.SentencePieceTrainer__TrainFromString(arg)
  733. @staticmethod
  734. def _TrainFromMap(args):
  735. return _sentencepiece.SentencePieceTrainer__TrainFromMap(args)
  736. @staticmethod
  737. def _TrainFromMap2(args, iter):
  738. return _sentencepiece.SentencePieceTrainer__TrainFromMap2(args, iter)
  739. @staticmethod
  740. def _TrainFromMap3(args):
  741. return _sentencepiece.SentencePieceTrainer__TrainFromMap3(args)
  742. @staticmethod
  743. def _TrainFromMap4(args, iter):
  744. return _sentencepiece.SentencePieceTrainer__TrainFromMap4(args, iter)
  745. @staticmethod
  746. def _Train(arg=None, **kwargs):
  747. """Train Sentencepiece model. Accept both kwargs and legacy string arg."""
  748. if arg is not None and type(arg) is str:
  749. return SentencePieceTrainer._TrainFromString(arg)
  750. def _encode(value):
  751. """Encode value to CSV.."""
  752. if type(value) is list:
  753. if sys.version_info[0] == 3:
  754. f = StringIO()
  755. else:
  756. f = BytesIO()
  757. writer = csv.writer(f, lineterminator='')
  758. writer.writerow([str(v) for v in value])
  759. return f.getvalue()
  760. else:
  761. return str(value)
  762. sentence_iterator = None
  763. model_writer = None
  764. new_kwargs = {}
  765. for key, value in kwargs.items():
  766. if key in ['sentence_iterator', 'sentence_reader']:
  767. sentence_iterator = value
  768. elif key in ['model_writer']:
  769. model_writer = value
  770. else:
  771. new_kwargs[key] = _encode(value)
  772. if model_writer:
  773. if sentence_iterator:
  774. model_proto = SentencePieceTrainer._TrainFromMap4(new_kwargs,
  775. sentence_iterator)
  776. else:
  777. model_proto = SentencePieceTrainer._TrainFromMap3(new_kwargs)
  778. model_writer.write(model_proto)
  779. else:
  780. if sentence_iterator:
  781. return SentencePieceTrainer._TrainFromMap2(new_kwargs, sentence_iterator)
  782. else:
  783. return SentencePieceTrainer._TrainFromMap(new_kwargs)
  784. return None
  785. @staticmethod
  786. def Train(arg=None, logstream=None, **kwargs):
  787. with _LogStream(ostream=logstream):
  788. SentencePieceTrainer._Train(arg=arg, **kwargs)
# Register SentencePieceTrainer in _sentencepiece:
# SWIG-generated registration step that hands the Python proxy class to the
# extension module; keep exactly as generated.
_sentencepiece.SentencePieceTrainer_swigregister(SentencePieceTrainer)
  791. class SentencePieceNormalizer(object):
  792. thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
  793. __repr__ = _swig_repr
  794. def __init__(self):
  795. _sentencepiece.SentencePieceNormalizer_swiginit(self, _sentencepiece.new_SentencePieceNormalizer())
  796. __swig_destroy__ = _sentencepiece.delete_SentencePieceNormalizer
  797. def LoadFromSerializedProto(self, serialized):
  798. return _sentencepiece.SentencePieceNormalizer_LoadFromSerializedProto(self, serialized)
  799. def LoadFromRuleTSV(self, filename):
  800. return _sentencepiece.SentencePieceNormalizer_LoadFromRuleTSV(self, filename)
  801. def LoadFromRuleName(self, name):
  802. return _sentencepiece.SentencePieceNormalizer_LoadFromRuleName(self, name)
  803. def serialized_model_proto(self):
  804. return _sentencepiece.SentencePieceNormalizer_serialized_model_proto(self)
  805. def LoadFromFile(self, arg):
  806. return _sentencepiece.SentencePieceNormalizer_LoadFromFile(self, arg)
  807. def _Normalize(self, text):
  808. return _sentencepiece.SentencePieceNormalizer__Normalize(self, text)
  809. def _NormalizeWithOffsets(self, text):
  810. return _sentencepiece.SentencePieceNormalizer__NormalizeWithOffsets(self, text)
  811. def _SetProtoField(self, name, value):
  812. return _sentencepiece.SentencePieceNormalizer__SetProtoField(self, name, value)
  813. def Init(self,
  814. model_file=None,
  815. model_proto=None,
  816. rule_tsv=None,
  817. rule_name=None,
  818. add_dummy_prefix=False,
  819. escape_whitespaces=False,
  820. remove_extra_whitespaces=False):
  821. """Initialzie sentencePieceNormalizer.
  822. Args:
  823. model_file: The sentencepiece model file path.
  824. model_proto: The sentencepiece model serialized proto.
  825. rule_tsv: The normalization rule file in TSV format.
  826. rule_name: Pre-defined normalization name.
  827. add_dummy_prefix: add dummy prefix.
  828. escape_whitespaces: escape whitespaces.
  829. remove_extra_whitespaces: remove extra whitespaces.
  830. """
  831. _sentencepiece_normalizer_init_native(self)
  832. if model_file:
  833. status = self.LoadFromFile(model_file)
  834. elif model_proto:
  835. status = self.LoadFromSerializedProto(model_proto)
  836. elif rule_tsv:
  837. status = self.LoadFromRuleTSV(rule_tsv)
  838. elif rule_name:
  839. status = self.LoadFromRuleName(rule_name)
  840. else:
  841. raise RuntimeError('no model is specified')
  842. if status:
  843. self._SetProtoField('add_dummy_prefix', add_dummy_prefix)
  844. self._SetProtoField('escape_whitespaces', escape_whitespaces)
  845. self._SetProtoField('remove_extra_whitespaces', remove_extra_whitespaces)
  846. def Normalize(self, input, with_offsets=None):
  847. def _normalize(text):
  848. if with_offsets:
  849. return self._NormalizeWithOffsets(text)
  850. return self._Normalize(text)
  851. if type(input) is list:
  852. return [_normalize(x) for x in input]
  853. return _normalize(input)
  854. def __getstate__(self):
  855. return self.serialized_model_proto()
  856. def __setstate__(self, serialized_model_proto):
  857. self.__init__()
  858. self.LoadFromSerializedProto(serialized_model_proto)
# Register SentencePieceNormalizer in _sentencepiece:
# SWIG-generated registration step that hands the Python proxy class to the
# extension module; keep exactly as generated.
_sentencepiece.SentencePieceNormalizer_swigregister(SentencePieceNormalizer)
  861. def SetDataDir(data_dir):
  862. return _sentencepiece.SetDataDir(data_dir)
  863. import re
  864. import csv
  865. import sys
  866. import os
  867. import importlib.resources
  868. from io import StringIO
  869. from io import BytesIO
  870. def _add_snake_case(classname):
  871. """Added snake_cased method from CammelCased method."""
  872. snake_map = {}
  873. for k, v in classname.__dict__.items():
  874. if re.match(r'^[A-Z]+', k):
  875. snake = re.sub(r'(?<!^)(?=[A-Z])', '_',
  876. k).lower().replace('n_best', 'nbest')
  877. snake_map[snake] = v
  878. for k, v in snake_map.items():
  879. setattr(classname, k, v)
  880. def _batchnize(classname, name):
  881. """Enables batch request for the method classname.name."""
  882. func = getattr(classname, name, None)
  883. def _func(v, n):
  884. if type(n) is int and (n < 0 or n >= v.piece_size()):
  885. raise IndexError('piece id is out of range.')
  886. return func(v, n)
  887. def _batched_func(self, arg):
  888. if type(arg) is list:
  889. return [_func(self, n) for n in arg]
  890. else:
  891. return _func(self, arg)
  892. setattr(classname, name, _batched_func)
# Keep the SWIG-generated constructors reachable under private names before
# `__init__` is rebound below. SentencePieceNormalizer.Init calls
# `_sentencepiece_normalizer_init_native(self)` to create the native object;
# the processor handle is presumably used the same way by
# SentencePieceProcessor.Init (defined outside this view -- confirm).
_sentencepiece_processor_init_native = SentencePieceProcessor.__init__
_sentencepiece_normalizer_init_native = SentencePieceNormalizer.__init__
# Route construction through the keyword-argument Init methods.
setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)
setattr(SentencePieceNormalizer, '__init__', SentencePieceNormalizer.Init)
# Alternative names for Encode / Decode.
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
# Let these per-piece lookups also accept a list of arguments, with int ids
# range-checked (see _batchnize).
for m in [
    'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused',
    'IsByte'
]:
    _batchnize(SentencePieceProcessor, m)
# Order matters: this runs after the rebinding/wrapping above, so the
# snake_case aliases (e.g. piece_to_id, tokenize) copy the final callables
# from each class __dict__.
_add_snake_case(SentencePieceProcessor)
_add_snake_case(SentencePieceTrainer)
_add_snake_case(SentencePieceNormalizer)
# snake_case aliases for the module-level functions.
set_random_generator_seed = SetRandomGeneratorSeed
set_min_log_level = SetMinLogLevel
from ._version import __version__
# Point the native library at the `package_data` directory shipped inside the
# installed `sentencepiece` package.
SetDataDir(os.path.join(str(importlib.resources.files('sentencepiece')), 'package_data'))
  911. class _LogStream(object):
  912. def __init__(self, ostream=None):
  913. self.ostream = ostream
  914. if self.ostream is not None:
  915. self.orig_stream_fileno = sys.stderr.fileno()
  916. def __enter__(self):
  917. if self.ostream is not None:
  918. self.orig_stream_dup = os.dup(self.orig_stream_fileno)
  919. os.dup2(self.ostream.fileno(), self.orig_stream_fileno)
  920. def __exit__(self, type, value, traceback):
  921. if self.ostream is not None:
  922. os.close(self.orig_stream_fileno)
  923. os.dup2(self.orig_stream_dup, self.orig_stream_fileno)
  924. os.close(self.orig_stream_dup)
  925. self.ostream.close()