__init__.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING
  15. from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
  16. _import_structure = {
  17. "configuration_utils": [
  18. "BaseWatermarkingConfig",
  19. "CompileConfig",
  20. "GenerationConfig",
  21. "GenerationMode",
  22. "SynthIDTextWatermarkingConfig",
  23. "WatermarkingConfig",
  24. ],
  25. "streamers": ["AsyncTextIteratorStreamer", "BaseStreamer", "TextIteratorStreamer", "TextStreamer"],
  26. }
  27. try:
  28. if not is_torch_available():
  29. raise OptionalDependencyNotAvailable()
  30. except OptionalDependencyNotAvailable:
  31. pass
  32. else:
  33. _import_structure["beam_constraints"] = [
  34. "Constraint",
  35. "ConstraintListState",
  36. "DisjunctiveConstraint",
  37. "PhrasalConstraint",
  38. ]
  39. _import_structure["beam_search"] = [
  40. "BeamHypotheses",
  41. "BeamScorer",
  42. "ConstrainedBeamSearchScorer",
  43. ]
  44. _import_structure["candidate_generator"] = [
  45. "AssistedCandidateGenerator",
  46. "CandidateGenerator",
  47. "EarlyExitCandidateGenerator",
  48. "PromptLookupCandidateGenerator",
  49. ]
  50. _import_structure["logits_process"] = [
  51. "AlternatingCodebooksLogitsProcessor",
  52. "ClassifierFreeGuidanceLogitsProcessor",
  53. "EncoderNoRepeatNGramLogitsProcessor",
  54. "EncoderRepetitionPenaltyLogitsProcessor",
  55. "EpsilonLogitsWarper",
  56. "EtaLogitsWarper",
  57. "ExponentialDecayLengthPenalty",
  58. "ForcedBOSTokenLogitsProcessor",
  59. "ForcedEOSTokenLogitsProcessor",
  60. "InfNanRemoveLogitsProcessor",
  61. "LogitNormalization",
  62. "LogitsProcessor",
  63. "LogitsProcessorList",
  64. "MinLengthLogitsProcessor",
  65. "MinNewTokensLengthLogitsProcessor",
  66. "MinPLogitsWarper",
  67. "NoBadWordsLogitsProcessor",
  68. "NoRepeatNGramLogitsProcessor",
  69. "PrefixConstrainedLogitsProcessor",
  70. "RepetitionPenaltyLogitsProcessor",
  71. "SequenceBiasLogitsProcessor",
  72. "SuppressTokensLogitsProcessor",
  73. "SuppressTokensAtBeginLogitsProcessor",
  74. "SynthIDTextWatermarkLogitsProcessor",
  75. "TemperatureLogitsWarper",
  76. "TopKLogitsWarper",
  77. "TopPLogitsWarper",
  78. "TypicalLogitsWarper",
  79. "UnbatchedClassifierFreeGuidanceLogitsProcessor",
  80. "WhisperTimeStampLogitsProcessor",
  81. "WatermarkLogitsProcessor",
  82. ]
  83. _import_structure["stopping_criteria"] = [
  84. "MaxLengthCriteria",
  85. "MaxTimeCriteria",
  86. "ConfidenceCriteria",
  87. "EosTokenCriteria",
  88. "StoppingCriteria",
  89. "StoppingCriteriaList",
  90. "validate_stopping_criteria",
  91. "StopStringCriteria",
  92. ]
  93. _import_structure["continuous_batching"] = [
  94. "ContinuousMixin",
  95. ]
  96. _import_structure["utils"] = [
  97. "GenerationMixin",
  98. "GreedySearchEncoderDecoderOutput",
  99. "GreedySearchDecoderOnlyOutput",
  100. "SampleEncoderDecoderOutput",
  101. "SampleDecoderOnlyOutput",
  102. "BeamSearchEncoderDecoderOutput",
  103. "BeamSearchDecoderOnlyOutput",
  104. "BeamSampleEncoderDecoderOutput",
  105. "BeamSampleDecoderOnlyOutput",
  106. "ContrastiveSearchEncoderDecoderOutput",
  107. "ContrastiveSearchDecoderOnlyOutput",
  108. "GenerateBeamDecoderOnlyOutput",
  109. "GenerateBeamEncoderDecoderOutput",
  110. "GenerateDecoderOnlyOutput",
  111. "GenerateEncoderDecoderOutput",
  112. ]
  113. _import_structure["watermarking"] = [
  114. "WatermarkDetector",
  115. "WatermarkDetectorOutput",
  116. "BayesianDetectorModel",
  117. "BayesianDetectorConfig",
  118. "SynthIDTextWatermarkDetector",
  119. ]
  120. try:
  121. if not is_tf_available():
  122. raise OptionalDependencyNotAvailable()
  123. except OptionalDependencyNotAvailable:
  124. pass
  125. else:
  126. _import_structure["tf_logits_process"] = [
  127. "TFForcedBOSTokenLogitsProcessor",
  128. "TFForcedEOSTokenLogitsProcessor",
  129. "TFForceTokensLogitsProcessor",
  130. "TFLogitsProcessor",
  131. "TFLogitsProcessorList",
  132. "TFLogitsWarper",
  133. "TFMinLengthLogitsProcessor",
  134. "TFNoBadWordsLogitsProcessor",
  135. "TFNoRepeatNGramLogitsProcessor",
  136. "TFRepetitionPenaltyLogitsProcessor",
  137. "TFSuppressTokensAtBeginLogitsProcessor",
  138. "TFSuppressTokensLogitsProcessor",
  139. "TFTemperatureLogitsWarper",
  140. "TFTopKLogitsWarper",
  141. "TFTopPLogitsWarper",
  142. ]
  143. _import_structure["tf_utils"] = [
  144. "TFGenerationMixin",
  145. "TFGreedySearchDecoderOnlyOutput",
  146. "TFGreedySearchEncoderDecoderOutput",
  147. "TFSampleEncoderDecoderOutput",
  148. "TFSampleDecoderOnlyOutput",
  149. "TFBeamSearchEncoderDecoderOutput",
  150. "TFBeamSearchDecoderOnlyOutput",
  151. "TFBeamSampleEncoderDecoderOutput",
  152. "TFBeamSampleDecoderOnlyOutput",
  153. "TFContrastiveSearchEncoderDecoderOutput",
  154. "TFContrastiveSearchDecoderOnlyOutput",
  155. ]
  156. try:
  157. if not is_flax_available():
  158. raise OptionalDependencyNotAvailable()
  159. except OptionalDependencyNotAvailable:
  160. pass
  161. else:
  162. _import_structure["flax_logits_process"] = [
  163. "FlaxForcedBOSTokenLogitsProcessor",
  164. "FlaxForcedEOSTokenLogitsProcessor",
  165. "FlaxForceTokensLogitsProcessor",
  166. "FlaxLogitsProcessor",
  167. "FlaxLogitsProcessorList",
  168. "FlaxLogitsWarper",
  169. "FlaxMinLengthLogitsProcessor",
  170. "FlaxSuppressTokensAtBeginLogitsProcessor",
  171. "FlaxSuppressTokensLogitsProcessor",
  172. "FlaxTemperatureLogitsWarper",
  173. "FlaxTopKLogitsWarper",
  174. "FlaxTopPLogitsWarper",
  175. "FlaxWhisperTimeStampLogitsProcessor",
  176. "FlaxNoRepeatNGramLogitsProcessor",
  177. ]
  178. _import_structure["flax_utils"] = [
  179. "FlaxGenerationMixin",
  180. "FlaxGreedySearchOutput",
  181. "FlaxSampleOutput",
  182. "FlaxBeamSearchOutput",
  183. ]
  184. if TYPE_CHECKING:
  185. from .configuration_utils import (
  186. BaseWatermarkingConfig,
  187. CompileConfig,
  188. GenerationConfig,
  189. GenerationMode,
  190. SynthIDTextWatermarkingConfig,
  191. WatermarkingConfig,
  192. )
  193. from .streamers import AsyncTextIteratorStreamer, BaseStreamer, TextIteratorStreamer, TextStreamer
  194. try:
  195. if not is_torch_available():
  196. raise OptionalDependencyNotAvailable()
  197. except OptionalDependencyNotAvailable:
  198. pass
  199. else:
  200. from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
  201. from .beam_search import BeamHypotheses, BeamScorer, ConstrainedBeamSearchScorer
  202. from .candidate_generator import (
  203. AssistedCandidateGenerator,
  204. CandidateGenerator,
  205. EarlyExitCandidateGenerator,
  206. PromptLookupCandidateGenerator,
  207. )
  208. from .continuous_batching import ContinuousMixin
  209. from .logits_process import (
  210. AlternatingCodebooksLogitsProcessor,
  211. ClassifierFreeGuidanceLogitsProcessor,
  212. EncoderNoRepeatNGramLogitsProcessor,
  213. EncoderRepetitionPenaltyLogitsProcessor,
  214. EpsilonLogitsWarper,
  215. EtaLogitsWarper,
  216. ExponentialDecayLengthPenalty,
  217. ForcedBOSTokenLogitsProcessor,
  218. ForcedEOSTokenLogitsProcessor,
  219. InfNanRemoveLogitsProcessor,
  220. LogitNormalization,
  221. LogitsProcessor,
  222. LogitsProcessorList,
  223. MinLengthLogitsProcessor,
  224. MinNewTokensLengthLogitsProcessor,
  225. MinPLogitsWarper,
  226. NoBadWordsLogitsProcessor,
  227. NoRepeatNGramLogitsProcessor,
  228. PrefixConstrainedLogitsProcessor,
  229. RepetitionPenaltyLogitsProcessor,
  230. SequenceBiasLogitsProcessor,
  231. SuppressTokensAtBeginLogitsProcessor,
  232. SuppressTokensLogitsProcessor,
  233. SynthIDTextWatermarkLogitsProcessor,
  234. TemperatureLogitsWarper,
  235. TopKLogitsWarper,
  236. TopPLogitsWarper,
  237. TypicalLogitsWarper,
  238. UnbatchedClassifierFreeGuidanceLogitsProcessor,
  239. WatermarkLogitsProcessor,
  240. WhisperTimeStampLogitsProcessor,
  241. )
  242. from .stopping_criteria import (
  243. ConfidenceCriteria,
  244. EosTokenCriteria,
  245. MaxLengthCriteria,
  246. MaxTimeCriteria,
  247. StoppingCriteria,
  248. StoppingCriteriaList,
  249. StopStringCriteria,
  250. validate_stopping_criteria,
  251. )
  252. from .utils import (
  253. BeamSampleDecoderOnlyOutput,
  254. BeamSampleEncoderDecoderOutput,
  255. BeamSearchDecoderOnlyOutput,
  256. BeamSearchEncoderDecoderOutput,
  257. ContrastiveSearchDecoderOnlyOutput,
  258. ContrastiveSearchEncoderDecoderOutput,
  259. GenerateBeamDecoderOnlyOutput,
  260. GenerateBeamEncoderDecoderOutput,
  261. GenerateDecoderOnlyOutput,
  262. GenerateEncoderDecoderOutput,
  263. GenerationMixin,
  264. GreedySearchDecoderOnlyOutput,
  265. GreedySearchEncoderDecoderOutput,
  266. SampleDecoderOnlyOutput,
  267. SampleEncoderDecoderOutput,
  268. )
  269. from .watermarking import (
  270. BayesianDetectorConfig,
  271. BayesianDetectorModel,
  272. SynthIDTextWatermarkDetector,
  273. WatermarkDetector,
  274. WatermarkDetectorOutput,
  275. )
  276. try:
  277. if not is_tf_available():
  278. raise OptionalDependencyNotAvailable()
  279. except OptionalDependencyNotAvailable:
  280. pass
  281. else:
  282. from .tf_logits_process import (
  283. TFForcedBOSTokenLogitsProcessor,
  284. TFForcedEOSTokenLogitsProcessor,
  285. TFForceTokensLogitsProcessor,
  286. TFLogitsProcessor,
  287. TFLogitsProcessorList,
  288. TFLogitsWarper,
  289. TFMinLengthLogitsProcessor,
  290. TFNoBadWordsLogitsProcessor,
  291. TFNoRepeatNGramLogitsProcessor,
  292. TFRepetitionPenaltyLogitsProcessor,
  293. TFSuppressTokensAtBeginLogitsProcessor,
  294. TFSuppressTokensLogitsProcessor,
  295. TFTemperatureLogitsWarper,
  296. TFTopKLogitsWarper,
  297. TFTopPLogitsWarper,
  298. )
  299. from .tf_utils import (
  300. TFBeamSampleDecoderOnlyOutput,
  301. TFBeamSampleEncoderDecoderOutput,
  302. TFBeamSearchDecoderOnlyOutput,
  303. TFBeamSearchEncoderDecoderOutput,
  304. TFContrastiveSearchDecoderOnlyOutput,
  305. TFContrastiveSearchEncoderDecoderOutput,
  306. TFGenerationMixin,
  307. TFGreedySearchDecoderOnlyOutput,
  308. TFGreedySearchEncoderDecoderOutput,
  309. TFSampleDecoderOnlyOutput,
  310. TFSampleEncoderDecoderOutput,
  311. )
  312. try:
  313. if not is_flax_available():
  314. raise OptionalDependencyNotAvailable()
  315. except OptionalDependencyNotAvailable:
  316. pass
  317. else:
  318. from .flax_logits_process import (
  319. FlaxForcedBOSTokenLogitsProcessor,
  320. FlaxForcedEOSTokenLogitsProcessor,
  321. FlaxForceTokensLogitsProcessor,
  322. FlaxLogitsProcessor,
  323. FlaxLogitsProcessorList,
  324. FlaxLogitsWarper,
  325. FlaxMinLengthLogitsProcessor,
  326. FlaxNoRepeatNGramLogitsProcessor,
  327. FlaxSuppressTokensAtBeginLogitsProcessor,
  328. FlaxSuppressTokensLogitsProcessor,
  329. FlaxTemperatureLogitsWarper,
  330. FlaxTopKLogitsWarper,
  331. FlaxTopPLogitsWarper,
  332. FlaxWhisperTimeStampLogitsProcessor,
  333. )
  334. from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
  335. else:
  336. import sys
  337. sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)