__init__.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
# When adding a new object to this init, remember to add it twice: once inside the `_import_structure` dictionary and
# once inside the `if TYPE_CHECKING` branch. The `TYPE_CHECKING` branch should have import statements as usual, but
# they are only there for type checking. `_import_structure` is a dictionary mapping each submodule to a list of
# object names, and is used to defer the actual importing until the objects are requested. This way
# `import transformers` provides the names in the namespace without actually importing anything (and especially none
# of the backends).
  19. __version__ = "4.57.3"
  20. from pathlib import Path
  21. from typing import TYPE_CHECKING
  22. # Check the dependencies satisfy the minimal versions required.
  23. from . import dependency_versions_check
  24. from .utils import (
  25. OptionalDependencyNotAvailable,
  26. _LazyModule,
  27. is_essentia_available,
  28. is_g2p_en_available,
  29. is_librosa_available,
  30. is_mistral_common_available,
  31. is_mlx_available,
  32. is_pretty_midi_available,
  33. )
  34. # Note: the following symbols are deliberately exported with `as`
  35. # so that mypy, pylint or other static linters can recognize them,
  36. # given that they are not exported using `__all__` in this file.
  37. from .utils import is_bitsandbytes_available as is_bitsandbytes_available
  38. from .utils import is_flax_available as is_flax_available
  39. from .utils import is_keras_nlp_available as is_keras_nlp_available
  40. from .utils import is_scipy_available as is_scipy_available
  41. from .utils import is_sentencepiece_available as is_sentencepiece_available
  42. from .utils import is_speech_available as is_speech_available
  43. from .utils import is_tensorflow_text_available as is_tensorflow_text_available
  44. from .utils import is_tf_available as is_tf_available
  45. from .utils import is_timm_available as is_timm_available
  46. from .utils import is_tokenizers_available as is_tokenizers_available
  47. from .utils import is_torch_available as is_torch_available
  48. from .utils import is_torchaudio_available as is_torchaudio_available
  49. from .utils import is_torchvision_available as is_torchvision_available
  50. from .utils import is_vision_available as is_vision_available
  51. from .utils import logging as logging
  52. from .utils.import_utils import define_import_structure
  53. logger = logging.get_logger(__name__) # pylint: disable=invalid-name
  54. # Base objects, independent of any specific backend
  55. _import_structure = {
  56. "audio_utils": [],
  57. "commands": [],
  58. "configuration_utils": ["PretrainedConfig"],
  59. "convert_graph_to_onnx": [],
  60. "convert_slow_tokenizers_checkpoints_to_fast": [],
  61. "convert_tf_hub_seq_to_seq_bert_to_pytorch": [],
  62. "data": [
  63. "DataProcessor",
  64. "InputExample",
  65. "InputFeatures",
  66. "SingleSentenceClassificationProcessor",
  67. "SquadExample",
  68. "SquadFeatures",
  69. "SquadV1Processor",
  70. "SquadV2Processor",
  71. "glue_compute_metrics",
  72. "glue_convert_examples_to_features",
  73. "glue_output_modes",
  74. "glue_processors",
  75. "glue_tasks_num_labels",
  76. "squad_convert_examples_to_features",
  77. "xnli_compute_metrics",
  78. "xnli_output_modes",
  79. "xnli_processors",
  80. "xnli_tasks_num_labels",
  81. ],
  82. "data.data_collator": [
  83. "DataCollator",
  84. "DataCollatorForLanguageModeling",
  85. "DataCollatorForMultipleChoice",
  86. "DataCollatorForPermutationLanguageModeling",
  87. "DataCollatorForSeq2Seq",
  88. "DataCollatorForSOP",
  89. "DataCollatorForTokenClassification",
  90. "DataCollatorForWholeWordMask",
  91. "DataCollatorWithFlattening",
  92. "DataCollatorWithPadding",
  93. "DefaultDataCollator",
  94. "default_data_collator",
  95. ],
  96. "data.metrics": [],
  97. "data.processors": [],
  98. "debug_utils": [],
  99. "dependency_versions_check": [],
  100. "dependency_versions_table": [],
  101. "dynamic_module_utils": [],
  102. "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
  103. "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
  104. "file_utils": [],
  105. "generation": [
  106. "AsyncTextIteratorStreamer",
  107. "CompileConfig",
  108. "GenerationConfig",
  109. "TextIteratorStreamer",
  110. "TextStreamer",
  111. "WatermarkingConfig",
  112. ],
  113. "hf_argparser": ["HfArgumentParser"],
  114. "hyperparameter_search": [],
  115. "image_transforms": [],
  116. "integrations": [
  117. "is_clearml_available",
  118. "is_comet_available",
  119. "is_dvclive_available",
  120. "is_neptune_available",
  121. "is_optuna_available",
  122. "is_ray_available",
  123. "is_ray_tune_available",
  124. "is_sigopt_available",
  125. "is_swanlab_available",
  126. "is_tensorboard_available",
  127. "is_trackio_available",
  128. "is_wandb_available",
  129. ],
  130. "loss": [],
  131. "modelcard": ["ModelCard"],
  132. # Losses
  133. "modeling_tf_pytorch_utils": [
  134. "convert_tf_weight_name_to_pt_weight_name",
  135. "load_pytorch_checkpoint_in_tf2_model",
  136. "load_pytorch_model_in_tf2_model",
  137. "load_pytorch_weights_in_tf2_model",
  138. "load_tf2_checkpoint_in_pytorch_model",
  139. "load_tf2_model_in_pytorch_model",
  140. "load_tf2_weights_in_pytorch_model",
  141. ],
  142. # Models
  143. "onnx": [],
  144. "pipelines": [
  145. "AudioClassificationPipeline",
  146. "AutomaticSpeechRecognitionPipeline",
  147. "CsvPipelineDataFormat",
  148. "DepthEstimationPipeline",
  149. "DocumentQuestionAnsweringPipeline",
  150. "FeatureExtractionPipeline",
  151. "FillMaskPipeline",
  152. "ImageClassificationPipeline",
  153. "ImageFeatureExtractionPipeline",
  154. "ImageSegmentationPipeline",
  155. "ImageTextToTextPipeline",
  156. "ImageToImagePipeline",
  157. "ImageToTextPipeline",
  158. "JsonPipelineDataFormat",
  159. "KeypointMatchingPipeline",
  160. "MaskGenerationPipeline",
  161. "NerPipeline",
  162. "ObjectDetectionPipeline",
  163. "PipedPipelineDataFormat",
  164. "Pipeline",
  165. "PipelineDataFormat",
  166. "QuestionAnsweringPipeline",
  167. "SummarizationPipeline",
  168. "TableQuestionAnsweringPipeline",
  169. "Text2TextGenerationPipeline",
  170. "TextClassificationPipeline",
  171. "TextGenerationPipeline",
  172. "TextToAudioPipeline",
  173. "TokenClassificationPipeline",
  174. "TranslationPipeline",
  175. "VideoClassificationPipeline",
  176. "VisualQuestionAnsweringPipeline",
  177. "ZeroShotAudioClassificationPipeline",
  178. "ZeroShotClassificationPipeline",
  179. "ZeroShotImageClassificationPipeline",
  180. "ZeroShotObjectDetectionPipeline",
  181. "pipeline",
  182. ],
  183. "processing_utils": ["ProcessorMixin"],
  184. "quantizers": [],
  185. "testing_utils": [],
  186. "tokenization_utils": ["PreTrainedTokenizer"],
  187. "tokenization_utils_base": [
  188. "AddedToken",
  189. "BatchEncoding",
  190. "CharSpan",
  191. "PreTrainedTokenizerBase",
  192. "SpecialTokensMixin",
  193. "TokenSpan",
  194. ],
  195. "trainer_callback": [
  196. "DefaultFlowCallback",
  197. "EarlyStoppingCallback",
  198. "PrinterCallback",
  199. "ProgressCallback",
  200. "TrainerCallback",
  201. "TrainerControl",
  202. "TrainerState",
  203. ],
  204. "trainer_utils": [
  205. "EvalPrediction",
  206. "IntervalStrategy",
  207. "SchedulerType",
  208. "enable_full_determinism",
  209. "set_seed",
  210. ],
  211. "training_args": ["TrainingArguments"],
  212. "training_args_seq2seq": ["Seq2SeqTrainingArguments"],
  213. "training_args_tf": ["TFTrainingArguments"],
  214. "utils": [
  215. "CONFIG_NAME",
  216. "MODEL_CARD_NAME",
  217. "PYTORCH_PRETRAINED_BERT_CACHE",
  218. "PYTORCH_TRANSFORMERS_CACHE",
  219. "SPIECE_UNDERLINE",
  220. "TF2_WEIGHTS_NAME",
  221. "TF_WEIGHTS_NAME",
  222. "TRANSFORMERS_CACHE",
  223. "WEIGHTS_NAME",
  224. "TensorType",
  225. "add_end_docstrings",
  226. "add_start_docstrings",
  227. "is_apex_available",
  228. "is_av_available",
  229. "is_bitsandbytes_available",
  230. "is_datasets_available",
  231. "is_faiss_available",
  232. "is_flax_available",
  233. "is_keras_nlp_available",
  234. "is_matplotlib_available",
  235. "is_mlx_available",
  236. "is_phonemizer_available",
  237. "is_psutil_available",
  238. "is_py3nvml_available",
  239. "is_pyctcdecode_available",
  240. "is_sacremoses_available",
  241. "is_safetensors_available",
  242. "is_scipy_available",
  243. "is_sentencepiece_available",
  244. "is_sklearn_available",
  245. "is_speech_available",
  246. "is_tensorflow_text_available",
  247. "is_tf_available",
  248. "is_timm_available",
  249. "is_tokenizers_available",
  250. "is_torch_available",
  251. "is_torch_hpu_available",
  252. "is_torch_mlu_available",
  253. "is_torch_musa_available",
  254. "is_torch_neuroncore_available",
  255. "is_torch_npu_available",
  256. "is_torchvision_available",
  257. "is_torch_xla_available",
  258. "is_torch_xpu_available",
  259. "is_vision_available",
  260. "logging",
  261. ],
  262. "utils.quantization_config": [
  263. "AqlmConfig",
  264. "AutoRoundConfig",
  265. "AwqConfig",
  266. "BitNetQuantConfig",
  267. "BitsAndBytesConfig",
  268. "CompressedTensorsConfig",
  269. "EetqConfig",
  270. "FbgemmFp8Config",
  271. "FineGrainedFP8Config",
  272. "GPTQConfig",
  273. "HiggsConfig",
  274. "HqqConfig",
  275. "Mxfp4Config",
  276. "QuantoConfig",
  277. "QuarkConfig",
  278. "FPQuantConfig",
  279. "SpQRConfig",
  280. "TorchAoConfig",
  281. "VptqConfig",
  282. ],
  283. "video_utils": [],
  284. }
  285. # tokenizers-backed objects
  286. try:
  287. if not is_tokenizers_available():
  288. raise OptionalDependencyNotAvailable()
  289. except OptionalDependencyNotAvailable:
  290. from .utils import dummy_tokenizers_objects
  291. _import_structure["utils.dummy_tokenizers_objects"] = [
  292. name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
  293. ]
  294. else:
  295. # Fast tokenizers structure
  296. _import_structure["tokenization_utils_fast"] = ["PreTrainedTokenizerFast"]
  297. try:
  298. if not (is_sentencepiece_available() and is_tokenizers_available()):
  299. raise OptionalDependencyNotAvailable()
  300. except OptionalDependencyNotAvailable:
  301. from .utils import dummy_sentencepiece_and_tokenizers_objects
  302. _import_structure["utils.dummy_sentencepiece_and_tokenizers_objects"] = [
  303. name for name in dir(dummy_sentencepiece_and_tokenizers_objects) if not name.startswith("_")
  304. ]
  305. else:
  306. _import_structure["convert_slow_tokenizer"] = [
  307. "SLOW_TO_FAST_CONVERTERS",
  308. "convert_slow_tokenizer",
  309. ]
  310. try:
  311. if not (is_mistral_common_available()):
  312. raise OptionalDependencyNotAvailable()
  313. except OptionalDependencyNotAvailable:
  314. from .utils import dummy_mistral_common_objects
  315. _import_structure["utils.dummy_mistral_common_objects"] = [
  316. name for name in dir(dummy_mistral_common_objects) if not name.startswith("_")
  317. ]
  318. else:
  319. _import_structure["tokenization_mistral_common"] = ["MistralCommonTokenizer"]
  320. # Vision-specific objects
  321. try:
  322. if not is_vision_available():
  323. raise OptionalDependencyNotAvailable()
  324. except OptionalDependencyNotAvailable:
  325. from .utils import dummy_vision_objects
  326. _import_structure["utils.dummy_vision_objects"] = [
  327. name for name in dir(dummy_vision_objects) if not name.startswith("_")
  328. ]
  329. else:
  330. _import_structure["image_processing_base"] = ["ImageProcessingMixin"]
  331. _import_structure["image_processing_utils"] = ["BaseImageProcessor"]
  332. _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
  333. try:
  334. if not is_torchvision_available():
  335. raise OptionalDependencyNotAvailable()
  336. except OptionalDependencyNotAvailable:
  337. from .utils import dummy_torchvision_objects
  338. _import_structure["utils.dummy_torchvision_objects"] = [
  339. name for name in dir(dummy_torchvision_objects) if not name.startswith("_")
  340. ]
  341. else:
  342. _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
  343. _import_structure["video_processing_utils"] = ["BaseVideoProcessor"]
  344. # PyTorch-backed objects
  345. try:
  346. if not is_torch_available():
  347. raise OptionalDependencyNotAvailable()
  348. except OptionalDependencyNotAvailable:
  349. from .utils import dummy_pt_objects
  350. _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
  351. else:
  352. _import_structure["model_debugging_utils"] = [
  353. "model_addition_debugger_context",
  354. ]
  355. _import_structure["activations"] = []
  356. _import_structure["cache_utils"] = [
  357. "CacheLayerMixin",
  358. "DynamicLayer",
  359. "StaticLayer",
  360. "StaticSlidingWindowLayer",
  361. "SlidingWindowLayer",
  362. "ChunkedSlidingLayer",
  363. "QuantoQuantizedLayer",
  364. "HQQQuantizedLayer",
  365. "Cache",
  366. "DynamicCache",
  367. "EncoderDecoderCache",
  368. "HQQQuantizedCache",
  369. "HybridCache",
  370. "HybridChunkedCache",
  371. "OffloadedCache",
  372. "OffloadedStaticCache",
  373. "QuantizedCache",
  374. "QuantoQuantizedCache",
  375. "SinkCache",
  376. "SlidingWindowCache",
  377. "StaticCache",
  378. ]
  379. _import_structure["data.datasets"] = [
  380. "GlueDataset",
  381. "GlueDataTrainingArguments",
  382. "LineByLineTextDataset",
  383. "LineByLineWithRefDataset",
  384. "LineByLineWithSOPTextDataset",
  385. "SquadDataset",
  386. "SquadDataTrainingArguments",
  387. "TextDataset",
  388. "TextDatasetForNextSentencePrediction",
  389. ]
  390. _import_structure["generation"].extend(
  391. [
  392. "AlternatingCodebooksLogitsProcessor",
  393. "BayesianDetectorConfig",
  394. "BayesianDetectorModel",
  395. "BeamScorer",
  396. "ClassifierFreeGuidanceLogitsProcessor",
  397. "ConstrainedBeamSearchScorer",
  398. "Constraint",
  399. "ConstraintListState",
  400. "DisjunctiveConstraint",
  401. "EncoderNoRepeatNGramLogitsProcessor",
  402. "EncoderRepetitionPenaltyLogitsProcessor",
  403. "EosTokenCriteria",
  404. "EpsilonLogitsWarper",
  405. "EtaLogitsWarper",
  406. "ExponentialDecayLengthPenalty",
  407. "ForcedBOSTokenLogitsProcessor",
  408. "ForcedEOSTokenLogitsProcessor",
  409. "GenerationMixin",
  410. "InfNanRemoveLogitsProcessor",
  411. "LogitNormalization",
  412. "LogitsProcessor",
  413. "LogitsProcessorList",
  414. "MaxLengthCriteria",
  415. "MaxTimeCriteria",
  416. "MinLengthLogitsProcessor",
  417. "MinNewTokensLengthLogitsProcessor",
  418. "MinPLogitsWarper",
  419. "NoBadWordsLogitsProcessor",
  420. "NoRepeatNGramLogitsProcessor",
  421. "PhrasalConstraint",
  422. "PrefixConstrainedLogitsProcessor",
  423. "RepetitionPenaltyLogitsProcessor",
  424. "SequenceBiasLogitsProcessor",
  425. "StoppingCriteria",
  426. "StoppingCriteriaList",
  427. "StopStringCriteria",
  428. "SuppressTokensAtBeginLogitsProcessor",
  429. "SuppressTokensLogitsProcessor",
  430. "SynthIDTextWatermarkDetector",
  431. "SynthIDTextWatermarkingConfig",
  432. "SynthIDTextWatermarkLogitsProcessor",
  433. "TemperatureLogitsWarper",
  434. "TopKLogitsWarper",
  435. "TopPLogitsWarper",
  436. "TypicalLogitsWarper",
  437. "UnbatchedClassifierFreeGuidanceLogitsProcessor",
  438. "WatermarkDetector",
  439. "WatermarkLogitsProcessor",
  440. "WhisperTimeStampLogitsProcessor",
  441. ]
  442. )
  443. # PyTorch domain libraries integration
  444. _import_structure["integrations.executorch"] = [
  445. "TorchExportableModuleWithStaticCache",
  446. "convert_and_export_with_cache",
  447. ]
  448. _import_structure["modeling_flash_attention_utils"] = []
  449. _import_structure["modeling_layers"] = ["GradientCheckpointingLayer"]
  450. _import_structure["modeling_outputs"] = []
  451. _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"]
  452. _import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
  453. _import_structure["masking_utils"] = ["AttentionMaskInterface"]
  454. _import_structure["optimization"] = [
  455. "Adafactor",
  456. "get_constant_schedule",
  457. "get_constant_schedule_with_warmup",
  458. "get_cosine_schedule_with_warmup",
  459. "get_cosine_with_hard_restarts_schedule_with_warmup",
  460. "get_cosine_with_min_lr_schedule_with_warmup",
  461. "get_cosine_with_min_lr_schedule_with_warmup_lr_rate",
  462. "get_inverse_sqrt_schedule",
  463. "get_linear_schedule_with_warmup",
  464. "get_polynomial_decay_schedule_with_warmup",
  465. "get_scheduler",
  466. "get_wsd_schedule",
  467. "get_reduce_on_plateau_schedule",
  468. ]
  469. _import_structure["pytorch_utils"] = [
  470. "Conv1D",
  471. "apply_chunking_to_forward",
  472. "prune_layer",
  473. "infer_device",
  474. ]
  475. _import_structure["sagemaker"] = []
  476. _import_structure["time_series_utils"] = []
  477. _import_structure["trainer"] = ["Trainer"]
  478. _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
  479. _import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"]
  480. # TensorFlow-backed objects
  481. try:
  482. if not is_tf_available():
  483. raise OptionalDependencyNotAvailable()
  484. except OptionalDependencyNotAvailable:
  485. from .utils import dummy_tf_objects
  486. _import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")]
  487. else:
  488. _import_structure["activations_tf"] = []
  489. _import_structure["generation"].extend(
  490. [
  491. "TFForcedBOSTokenLogitsProcessor",
  492. "TFForcedEOSTokenLogitsProcessor",
  493. "TFForceTokensLogitsProcessor",
  494. "TFGenerationMixin",
  495. "TFLogitsProcessor",
  496. "TFLogitsProcessorList",
  497. "TFLogitsWarper",
  498. "TFMinLengthLogitsProcessor",
  499. "TFNoBadWordsLogitsProcessor",
  500. "TFNoRepeatNGramLogitsProcessor",
  501. "TFRepetitionPenaltyLogitsProcessor",
  502. "TFSuppressTokensAtBeginLogitsProcessor",
  503. "TFSuppressTokensLogitsProcessor",
  504. "TFTemperatureLogitsWarper",
  505. "TFTopKLogitsWarper",
  506. "TFTopPLogitsWarper",
  507. ]
  508. )
  509. _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
  510. _import_structure["modeling_tf_outputs"] = []
  511. _import_structure["modeling_tf_utils"] = [
  512. "TFPreTrainedModel",
  513. "TFSequenceSummary",
  514. "TFSharedEmbeddings",
  515. "shape_list",
  516. ]
  517. _import_structure["optimization_tf"] = [
  518. "AdamWeightDecay",
  519. "GradientAccumulator",
  520. "WarmUp",
  521. "create_optimizer",
  522. ]
  523. _import_structure["tf_utils"] = []
  524. # FLAX-backed objects
  525. try:
  526. if not is_flax_available():
  527. raise OptionalDependencyNotAvailable()
  528. except OptionalDependencyNotAvailable:
  529. from .utils import dummy_flax_objects
  530. _import_structure["utils.dummy_flax_objects"] = [
  531. name for name in dir(dummy_flax_objects) if not name.startswith("_")
  532. ]
  533. else:
  534. _import_structure["generation"].extend(
  535. [
  536. "FlaxForcedBOSTokenLogitsProcessor",
  537. "FlaxForcedEOSTokenLogitsProcessor",
  538. "FlaxForceTokensLogitsProcessor",
  539. "FlaxGenerationMixin",
  540. "FlaxLogitsProcessor",
  541. "FlaxLogitsProcessorList",
  542. "FlaxLogitsWarper",
  543. "FlaxMinLengthLogitsProcessor",
  544. "FlaxTemperatureLogitsWarper",
  545. "FlaxSuppressTokensAtBeginLogitsProcessor",
  546. "FlaxSuppressTokensLogitsProcessor",
  547. "FlaxTopKLogitsWarper",
  548. "FlaxTopPLogitsWarper",
  549. "FlaxWhisperTimeStampLogitsProcessor",
  550. ]
  551. )
  552. _import_structure["modeling_flax_outputs"] = []
  553. _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
  554. # Direct imports for type-checking
  555. if TYPE_CHECKING:
  556. # All modeling imports
  557. from .cache_utils import Cache as Cache
  558. from .cache_utils import ChunkedSlidingLayer as ChunkedSlidingLayer
  559. from .cache_utils import DynamicCache as DynamicCache
  560. from .cache_utils import DynamicLayer as DynamicLayer
  561. from .cache_utils import EncoderDecoderCache as EncoderDecoderCache
  562. from .cache_utils import HQQQuantizedCache as HQQQuantizedCache
  563. from .cache_utils import HQQQuantizedLayer as HQQQuantizedLayer
  564. from .cache_utils import HybridCache as HybridCache
  565. from .cache_utils import OffloadedCache as OffloadedCache
  566. from .cache_utils import OffloadedStaticCache as OffloadedStaticCache
  567. from .cache_utils import QuantizedCache as QuantizedCache
  568. from .cache_utils import QuantoQuantizedCache as QuantoQuantizedCache
  569. from .cache_utils import QuantoQuantizedLayer as QuantoQuantizedLayer
  570. from .cache_utils import SinkCache as SinkCache
  571. from .cache_utils import SlidingWindowCache as SlidingWindowCache
  572. from .cache_utils import SlidingWindowLayer as SlidingWindowLayer
  573. from .cache_utils import StaticCache as StaticCache
  574. from .cache_utils import StaticLayer as StaticLayer
  575. from .cache_utils import StaticSlidingWindowLayer as StaticSlidingWindowLayer
  576. from .configuration_utils import PretrainedConfig as PretrainedConfig
  577. from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS as SLOW_TO_FAST_CONVERTERS
  578. from .convert_slow_tokenizer import convert_slow_tokenizer as convert_slow_tokenizer
  579. # Data
  580. from .data import DataProcessor as DataProcessor
  581. from .data import InputExample as InputExample
  582. from .data import InputFeatures as InputFeatures
  583. from .data import SingleSentenceClassificationProcessor as SingleSentenceClassificationProcessor
  584. from .data import SquadExample as SquadExample
  585. from .data import SquadFeatures as SquadFeatures
  586. from .data import SquadV1Processor as SquadV1Processor
  587. from .data import SquadV2Processor as SquadV2Processor
  588. from .data import glue_compute_metrics as glue_compute_metrics
  589. from .data import glue_convert_examples_to_features as glue_convert_examples_to_features
  590. from .data import glue_output_modes as glue_output_modes
  591. from .data import glue_processors as glue_processors
  592. from .data import glue_tasks_num_labels as glue_tasks_num_labels
  593. from .data import squad_convert_examples_to_features as squad_convert_examples_to_features
  594. from .data import xnli_compute_metrics as xnli_compute_metrics
  595. from .data import xnli_output_modes as xnli_output_modes
  596. from .data import xnli_processors as xnli_processors
  597. from .data import xnli_tasks_num_labels as xnli_tasks_num_labels
  598. from .data.data_collator import DataCollator as DataCollator
  599. from .data.data_collator import DataCollatorForLanguageModeling as DataCollatorForLanguageModeling
  600. from .data.data_collator import DataCollatorForMultipleChoice as DataCollatorForMultipleChoice
  601. from .data.data_collator import (
  602. DataCollatorForPermutationLanguageModeling as DataCollatorForPermutationLanguageModeling,
  603. )
  604. from .data.data_collator import DataCollatorForSeq2Seq as DataCollatorForSeq2Seq
  605. from .data.data_collator import DataCollatorForSOP as DataCollatorForSOP
  606. from .data.data_collator import DataCollatorForTokenClassification as DataCollatorForTokenClassification
  607. from .data.data_collator import DataCollatorForWholeWordMask as DataCollatorForWholeWordMask
  608. from .data.data_collator import DataCollatorWithFlattening as DataCollatorWithFlattening
  609. from .data.data_collator import DataCollatorWithPadding as DataCollatorWithPadding
  610. from .data.data_collator import DefaultDataCollator as DefaultDataCollator
  611. from .data.data_collator import default_data_collator as default_data_collator
  612. from .data.datasets import GlueDataset as GlueDataset
  613. from .data.datasets import GlueDataTrainingArguments as GlueDataTrainingArguments
  614. from .data.datasets import LineByLineTextDataset as LineByLineTextDataset
  615. from .data.datasets import LineByLineWithRefDataset as LineByLineWithRefDataset
  616. from .data.datasets import LineByLineWithSOPTextDataset as LineByLineWithSOPTextDataset
  617. from .data.datasets import SquadDataset as SquadDataset
  618. from .data.datasets import SquadDataTrainingArguments as SquadDataTrainingArguments
  619. from .data.datasets import TextDataset as TextDataset
  620. from .data.datasets import TextDatasetForNextSentencePrediction as TextDatasetForNextSentencePrediction
  621. from .feature_extraction_sequence_utils import SequenceFeatureExtractor as SequenceFeatureExtractor
  622. # Feature Extractor
  623. from .feature_extraction_utils import BatchFeature as BatchFeature
  624. from .feature_extraction_utils import FeatureExtractionMixin as FeatureExtractionMixin
  625. # Generation
  626. from .generation import AlternatingCodebooksLogitsProcessor as AlternatingCodebooksLogitsProcessor
  627. from .generation import AsyncTextIteratorStreamer as AsyncTextIteratorStreamer
  628. from .generation import BayesianDetectorConfig as BayesianDetectorConfig
  629. from .generation import BayesianDetectorModel as BayesianDetectorModel
  630. from .generation import BeamScorer as BeamScorer
  631. from .generation import ClassifierFreeGuidanceLogitsProcessor as ClassifierFreeGuidanceLogitsProcessor
  632. from .generation import CompileConfig as CompileConfig
  633. from .generation import ConstrainedBeamSearchScorer as ConstrainedBeamSearchScorer
  634. from .generation import Constraint as Constraint
  635. from .generation import ConstraintListState as ConstraintListState
  636. from .generation import DisjunctiveConstraint as DisjunctiveConstraint
  637. from .generation import EncoderNoRepeatNGramLogitsProcessor as EncoderNoRepeatNGramLogitsProcessor
  638. from .generation import EncoderRepetitionPenaltyLogitsProcessor as EncoderRepetitionPenaltyLogitsProcessor
  639. from .generation import EosTokenCriteria as EosTokenCriteria
  640. from .generation import EpsilonLogitsWarper as EpsilonLogitsWarper
  641. from .generation import EtaLogitsWarper as EtaLogitsWarper
  642. from .generation import ExponentialDecayLengthPenalty as ExponentialDecayLengthPenalty
  643. from .generation import FlaxForcedBOSTokenLogitsProcessor as FlaxForcedBOSTokenLogitsProcessor
  644. from .generation import FlaxForcedEOSTokenLogitsProcessor as FlaxForcedEOSTokenLogitsProcessor
  645. from .generation import FlaxForceTokensLogitsProcessor as FlaxForceTokensLogitsProcessor
  646. from .generation import FlaxGenerationMixin as FlaxGenerationMixin
  647. from .generation import FlaxLogitsProcessor as FlaxLogitsProcessor
  648. from .generation import FlaxLogitsProcessorList as FlaxLogitsProcessorList
  649. from .generation import FlaxLogitsWarper as FlaxLogitsWarper
  650. from .generation import FlaxMinLengthLogitsProcessor as FlaxMinLengthLogitsProcessor
  651. from .generation import FlaxSuppressTokensAtBeginLogitsProcessor as FlaxSuppressTokensAtBeginLogitsProcessor
  652. from .generation import FlaxSuppressTokensLogitsProcessor as FlaxSuppressTokensLogitsProcessor
  653. from .generation import FlaxTemperatureLogitsWarper as FlaxTemperatureLogitsWarper
  654. from .generation import FlaxTopKLogitsWarper as FlaxTopKLogitsWarper
  655. from .generation import FlaxTopPLogitsWarper as FlaxTopPLogitsWarper
  656. from .generation import FlaxWhisperTimeStampLogitsProcessor as FlaxWhisperTimeStampLogitsProcessor
  657. from .generation import ForcedBOSTokenLogitsProcessor as ForcedBOSTokenLogitsProcessor
  658. from .generation import ForcedEOSTokenLogitsProcessor as ForcedEOSTokenLogitsProcessor
  659. from .generation import GenerationConfig as GenerationConfig
  660. from .generation import GenerationMixin as GenerationMixin
  661. from .generation import InfNanRemoveLogitsProcessor as InfNanRemoveLogitsProcessor
  662. from .generation import LogitNormalization as LogitNormalization
  663. from .generation import LogitsProcessor as LogitsProcessor
  664. from .generation import LogitsProcessorList as LogitsProcessorList
  665. from .generation import MaxLengthCriteria as MaxLengthCriteria
  666. from .generation import MaxTimeCriteria as MaxTimeCriteria
  667. from .generation import MinLengthLogitsProcessor as MinLengthLogitsProcessor
  668. from .generation import MinNewTokensLengthLogitsProcessor as MinNewTokensLengthLogitsProcessor
  669. from .generation import MinPLogitsWarper as MinPLogitsWarper
  670. from .generation import NoBadWordsLogitsProcessor as NoBadWordsLogitsProcessor
  671. from .generation import NoRepeatNGramLogitsProcessor as NoRepeatNGramLogitsProcessor
  672. from .generation import PhrasalConstraint as PhrasalConstraint
  673. from .generation import PrefixConstrainedLogitsProcessor as PrefixConstrainedLogitsProcessor
  674. from .generation import RepetitionPenaltyLogitsProcessor as RepetitionPenaltyLogitsProcessor
  675. from .generation import SequenceBiasLogitsProcessor as SequenceBiasLogitsProcessor
  676. from .generation import StoppingCriteria as StoppingCriteria
  677. from .generation import StoppingCriteriaList as StoppingCriteriaList
  678. from .generation import StopStringCriteria as StopStringCriteria
  679. from .generation import SuppressTokensAtBeginLogitsProcessor as SuppressTokensAtBeginLogitsProcessor
  680. from .generation import SuppressTokensLogitsProcessor as SuppressTokensLogitsProcessor
  681. from .generation import SynthIDTextWatermarkDetector as SynthIDTextWatermarkDetector
  682. from .generation import SynthIDTextWatermarkingConfig as SynthIDTextWatermarkingConfig
  683. from .generation import SynthIDTextWatermarkLogitsProcessor as SynthIDTextWatermarkLogitsProcessor
  684. from .generation import TemperatureLogitsWarper as TemperatureLogitsWarper
  685. from .generation import TextIteratorStreamer as TextIteratorStreamer
  686. from .generation import TextStreamer as TextStreamer
  687. from .generation import TFForcedBOSTokenLogitsProcessor as TFForcedBOSTokenLogitsProcessor
  688. from .generation import TFForcedEOSTokenLogitsProcessor as TFForcedEOSTokenLogitsProcessor
  689. from .generation import TFForceTokensLogitsProcessor as TFForceTokensLogitsProcessor
  690. from .generation import TFGenerationMixin as TFGenerationMixin
  691. from .generation import TFLogitsProcessor as TFLogitsProcessor
  692. from .generation import TFLogitsProcessorList as TFLogitsProcessorList
  693. from .generation import TFLogitsWarper as TFLogitsWarper
  694. from .generation import TFMinLengthLogitsProcessor as TFMinLengthLogitsProcessor
  695. from .generation import TFNoBadWordsLogitsProcessor as TFNoBadWordsLogitsProcessor
  696. from .generation import TFNoRepeatNGramLogitsProcessor as TFNoRepeatNGramLogitsProcessor
  697. from .generation import TFRepetitionPenaltyLogitsProcessor as TFRepetitionPenaltyLogitsProcessor
  698. from .generation import TFSuppressTokensAtBeginLogitsProcessor as TFSuppressTokensAtBeginLogitsProcessor
  699. from .generation import TFSuppressTokensLogitsProcessor as TFSuppressTokensLogitsProcessor
  700. from .generation import TFTemperatureLogitsWarper as TFTemperatureLogitsWarper
  701. from .generation import TFTopKLogitsWarper as TFTopKLogitsWarper
  702. from .generation import TFTopPLogitsWarper as TFTopPLogitsWarper
  703. from .generation import TopKLogitsWarper as TopKLogitsWarper
  704. from .generation import TopPLogitsWarper as TopPLogitsWarper
  705. from .generation import TypicalLogitsWarper as TypicalLogitsWarper
  706. from .generation import (
  707. UnbatchedClassifierFreeGuidanceLogitsProcessor as UnbatchedClassifierFreeGuidanceLogitsProcessor,
  708. )
  709. from .generation import WatermarkDetector as WatermarkDetector
  710. from .generation import WatermarkingConfig as WatermarkingConfig
  711. from .generation import WatermarkLogitsProcessor as WatermarkLogitsProcessor
  712. from .generation import WhisperTimeStampLogitsProcessor as WhisperTimeStampLogitsProcessor
  713. from .hf_argparser import HfArgumentParser as HfArgumentParser
  714. from .image_processing_base import ImageProcessingMixin as ImageProcessingMixin
  715. from .image_processing_utils import BaseImageProcessor as BaseImageProcessor
  716. from .image_processing_utils_fast import BaseImageProcessorFast as BaseImageProcessorFast
  717. from .image_utils import ImageFeatureExtractionMixin as ImageFeatureExtractionMixin
  718. # Integrations
  719. from .integrations import is_clearml_available as is_clearml_available
  720. from .integrations import is_comet_available as is_comet_available
  721. from .integrations import is_dvclive_available as is_dvclive_available
  722. from .integrations import is_neptune_available as is_neptune_available
  723. from .integrations import is_optuna_available as is_optuna_available
  724. from .integrations import is_ray_available as is_ray_available
  725. from .integrations import is_ray_tune_available as is_ray_tune_available
  726. from .integrations import is_sigopt_available as is_sigopt_available
  727. from .integrations import is_swanlab_available as is_swanlab_available
  728. from .integrations import is_tensorboard_available as is_tensorboard_available
  729. from .integrations import is_trackio_available as is_trackio_available
  730. from .integrations import is_wandb_available as is_wandb_available
  731. from .integrations.executorch import TorchExportableModuleWithStaticCache as TorchExportableModuleWithStaticCache
  732. from .integrations.executorch import convert_and_export_with_cache as convert_and_export_with_cache
  733. from .keras_callbacks import KerasMetricCallback as KerasMetricCallback
  734. from .keras_callbacks import PushToHubCallback as PushToHubCallback
  735. from .masking_utils import AttentionMaskInterface as AttentionMaskInterface
  736. from .model_debugging_utils import model_addition_debugger_context as model_addition_debugger_context
  737. # Model Cards
  738. from .modelcard import ModelCard as ModelCard
  739. from .modeling_flax_utils import FlaxPreTrainedModel as FlaxPreTrainedModel
  740. from .modeling_layers import GradientCheckpointingLayer as GradientCheckpointingLayer
  741. from .modeling_rope_utils import ROPE_INIT_FUNCTIONS as ROPE_INIT_FUNCTIONS
  742. from .modeling_rope_utils import dynamic_rope_update as dynamic_rope_update
  743. # TF 2.0 <=> PyTorch conversion utilities
  744. from .modeling_tf_pytorch_utils import (
  745. convert_tf_weight_name_to_pt_weight_name as convert_tf_weight_name_to_pt_weight_name,
  746. )
  747. from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model as load_pytorch_checkpoint_in_tf2_model
  748. from .modeling_tf_pytorch_utils import load_pytorch_model_in_tf2_model as load_pytorch_model_in_tf2_model
  749. from .modeling_tf_pytorch_utils import load_pytorch_weights_in_tf2_model as load_pytorch_weights_in_tf2_model
  750. from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model as load_tf2_checkpoint_in_pytorch_model
  751. from .modeling_tf_pytorch_utils import load_tf2_model_in_pytorch_model as load_tf2_model_in_pytorch_model
  752. from .modeling_tf_pytorch_utils import load_tf2_weights_in_pytorch_model as load_tf2_weights_in_pytorch_model
  753. from .modeling_tf_utils import TFPreTrainedModel as TFPreTrainedModel
  754. from .modeling_tf_utils import TFSequenceSummary as TFSequenceSummary
  755. from .modeling_tf_utils import TFSharedEmbeddings as TFSharedEmbeddings
  756. from .modeling_tf_utils import shape_list as shape_list
  757. from .modeling_utils import AttentionInterface as AttentionInterface
  758. from .modeling_utils import PreTrainedModel as PreTrainedModel
  759. from .models import *
  760. from .models.mamba.modeling_mamba import MambaCache as MambaCache
  761. from .models.timm_wrapper import TimmWrapperImageProcessor as TimmWrapperImageProcessor
  762. # Optimization
  763. from .optimization import Adafactor as Adafactor
  764. from .optimization import get_constant_schedule as get_constant_schedule
  765. from .optimization import get_constant_schedule_with_warmup as get_constant_schedule_with_warmup
  766. from .optimization import get_cosine_schedule_with_warmup as get_cosine_schedule_with_warmup
  767. from .optimization import (
  768. get_cosine_with_hard_restarts_schedule_with_warmup as get_cosine_with_hard_restarts_schedule_with_warmup,
  769. )
  770. from .optimization import (
  771. get_cosine_with_min_lr_schedule_with_warmup as get_cosine_with_min_lr_schedule_with_warmup,
  772. )
  773. from .optimization import (
  774. get_cosine_with_min_lr_schedule_with_warmup_lr_rate as get_cosine_with_min_lr_schedule_with_warmup_lr_rate,
  775. )
  776. from .optimization import get_inverse_sqrt_schedule as get_inverse_sqrt_schedule
  777. from .optimization import get_linear_schedule_with_warmup as get_linear_schedule_with_warmup
  778. from .optimization import get_polynomial_decay_schedule_with_warmup as get_polynomial_decay_schedule_with_warmup
  779. from .optimization import get_scheduler as get_scheduler
  780. from .optimization import get_wsd_schedule as get_wsd_schedule
    # Optimization (TensorFlow)
  782. from .optimization_tf import AdamWeightDecay as AdamWeightDecay
  783. from .optimization_tf import GradientAccumulator as GradientAccumulator
  784. from .optimization_tf import WarmUp as WarmUp
  785. from .optimization_tf import create_optimizer as create_optimizer
  786. # Pipelines
  787. from .pipelines import AudioClassificationPipeline as AudioClassificationPipeline
  788. from .pipelines import AutomaticSpeechRecognitionPipeline as AutomaticSpeechRecognitionPipeline
  789. from .pipelines import CsvPipelineDataFormat as CsvPipelineDataFormat
  790. from .pipelines import DepthEstimationPipeline as DepthEstimationPipeline
  791. from .pipelines import DocumentQuestionAnsweringPipeline as DocumentQuestionAnsweringPipeline
  792. from .pipelines import FeatureExtractionPipeline as FeatureExtractionPipeline
  793. from .pipelines import FillMaskPipeline as FillMaskPipeline
  794. from .pipelines import ImageClassificationPipeline as ImageClassificationPipeline
  795. from .pipelines import ImageFeatureExtractionPipeline as ImageFeatureExtractionPipeline
  796. from .pipelines import ImageSegmentationPipeline as ImageSegmentationPipeline
  797. from .pipelines import ImageTextToTextPipeline as ImageTextToTextPipeline
  798. from .pipelines import ImageToImagePipeline as ImageToImagePipeline
  799. from .pipelines import ImageToTextPipeline as ImageToTextPipeline
  800. from .pipelines import JsonPipelineDataFormat as JsonPipelineDataFormat
  801. from .pipelines import KeypointMatchingPipeline as KeypointMatchingPipeline
  802. from .pipelines import MaskGenerationPipeline as MaskGenerationPipeline
  803. from .pipelines import NerPipeline as NerPipeline
  804. from .pipelines import ObjectDetectionPipeline as ObjectDetectionPipeline
  805. from .pipelines import PipedPipelineDataFormat as PipedPipelineDataFormat
  806. from .pipelines import Pipeline as Pipeline
  807. from .pipelines import PipelineDataFormat as PipelineDataFormat
  808. from .pipelines import QuestionAnsweringPipeline as QuestionAnsweringPipeline
  809. from .pipelines import SummarizationPipeline as SummarizationPipeline
  810. from .pipelines import TableQuestionAnsweringPipeline as TableQuestionAnsweringPipeline
  811. from .pipelines import Text2TextGenerationPipeline as Text2TextGenerationPipeline
  812. from .pipelines import TextClassificationPipeline as TextClassificationPipeline
  813. from .pipelines import TextGenerationPipeline as TextGenerationPipeline
  814. from .pipelines import TextToAudioPipeline as TextToAudioPipeline
  815. from .pipelines import TokenClassificationPipeline as TokenClassificationPipeline
  816. from .pipelines import TranslationPipeline as TranslationPipeline
  817. from .pipelines import VideoClassificationPipeline as VideoClassificationPipeline
  818. from .pipelines import VisualQuestionAnsweringPipeline as VisualQuestionAnsweringPipeline
  819. from .pipelines import ZeroShotAudioClassificationPipeline as ZeroShotAudioClassificationPipeline
  820. from .pipelines import ZeroShotClassificationPipeline as ZeroShotClassificationPipeline
  821. from .pipelines import ZeroShotImageClassificationPipeline as ZeroShotImageClassificationPipeline
  822. from .pipelines import ZeroShotObjectDetectionPipeline as ZeroShotObjectDetectionPipeline
  823. from .pipelines import pipeline as pipeline
  824. from .processing_utils import ProcessorMixin as ProcessorMixin
  825. from .pytorch_utils import Conv1D as Conv1D
  826. from .pytorch_utils import apply_chunking_to_forward as apply_chunking_to_forward
  827. from .pytorch_utils import prune_layer as prune_layer
  828. # Tokenization
  829. from .tokenization_utils import PreTrainedTokenizer as PreTrainedTokenizer
  830. from .tokenization_utils_base import AddedToken as AddedToken
  831. from .tokenization_utils_base import BatchEncoding as BatchEncoding
  832. from .tokenization_utils_base import CharSpan as CharSpan
  833. from .tokenization_utils_base import PreTrainedTokenizerBase as PreTrainedTokenizerBase
  834. from .tokenization_utils_base import SpecialTokensMixin as SpecialTokensMixin
  835. from .tokenization_utils_base import TokenSpan as TokenSpan
  836. from .tokenization_utils_fast import PreTrainedTokenizerFast as PreTrainedTokenizerFast
  837. # Trainer
  838. from .trainer import Trainer as Trainer
    # Trainer callbacks, utilities and training arguments
  840. from .trainer_callback import DefaultFlowCallback as DefaultFlowCallback
  841. from .trainer_callback import EarlyStoppingCallback as EarlyStoppingCallback
  842. from .trainer_callback import PrinterCallback as PrinterCallback
  843. from .trainer_callback import ProgressCallback as ProgressCallback
  844. from .trainer_callback import TrainerCallback as TrainerCallback
  845. from .trainer_callback import TrainerControl as TrainerControl
  846. from .trainer_callback import TrainerState as TrainerState
  847. from .trainer_pt_utils import torch_distributed_zero_first as torch_distributed_zero_first
  848. from .trainer_seq2seq import Seq2SeqTrainer as Seq2SeqTrainer
  849. from .trainer_utils import EvalPrediction as EvalPrediction
  850. from .trainer_utils import IntervalStrategy as IntervalStrategy
  851. from .trainer_utils import SchedulerType as SchedulerType
  852. from .trainer_utils import enable_full_determinism as enable_full_determinism
  853. from .trainer_utils import set_seed as set_seed
  854. from .training_args import TrainingArguments as TrainingArguments
  855. from .training_args_seq2seq import Seq2SeqTrainingArguments as Seq2SeqTrainingArguments
  856. from .training_args_tf import TFTrainingArguments as TFTrainingArguments
  857. # Files and general utilities
  858. from .utils import CONFIG_NAME as CONFIG_NAME
  859. from .utils import MODEL_CARD_NAME as MODEL_CARD_NAME
  860. from .utils import PYTORCH_PRETRAINED_BERT_CACHE as PYTORCH_PRETRAINED_BERT_CACHE
  861. from .utils import PYTORCH_TRANSFORMERS_CACHE as PYTORCH_TRANSFORMERS_CACHE
  862. from .utils import SPIECE_UNDERLINE as SPIECE_UNDERLINE
  863. from .utils import TF2_WEIGHTS_NAME as TF2_WEIGHTS_NAME
  864. from .utils import TF_WEIGHTS_NAME as TF_WEIGHTS_NAME
  865. from .utils import TRANSFORMERS_CACHE as TRANSFORMERS_CACHE
  866. from .utils import WEIGHTS_NAME as WEIGHTS_NAME
  867. from .utils import TensorType as TensorType
  868. from .utils import add_end_docstrings as add_end_docstrings
  869. from .utils import add_start_docstrings as add_start_docstrings
  870. from .utils import is_apex_available as is_apex_available
  871. from .utils import is_av_available as is_av_available
  872. from .utils import is_datasets_available as is_datasets_available
  873. from .utils import is_faiss_available as is_faiss_available
  874. from .utils import is_matplotlib_available as is_matplotlib_available
  875. from .utils import is_phonemizer_available as is_phonemizer_available
  876. from .utils import is_psutil_available as is_psutil_available
  877. from .utils import is_py3nvml_available as is_py3nvml_available
  878. from .utils import is_pyctcdecode_available as is_pyctcdecode_available
  879. from .utils import is_sacremoses_available as is_sacremoses_available
  880. from .utils import is_safetensors_available as is_safetensors_available
  881. from .utils import is_sklearn_available as is_sklearn_available
  882. from .utils import is_torch_hpu_available as is_torch_hpu_available
  883. from .utils import is_torch_mlu_available as is_torch_mlu_available
  884. from .utils import is_torch_musa_available as is_torch_musa_available
  885. from .utils import is_torch_neuroncore_available as is_torch_neuroncore_available
  886. from .utils import is_torch_npu_available as is_torch_npu_available
  887. from .utils import is_torch_xla_available as is_torch_xla_available
  888. from .utils import is_torch_xpu_available as is_torch_xpu_available
    # Quantization configs
  890. from .utils.quantization_config import AqlmConfig as AqlmConfig
  891. from .utils.quantization_config import AutoRoundConfig as AutoRoundConfig
  892. from .utils.quantization_config import AwqConfig as AwqConfig
  893. from .utils.quantization_config import BitNetQuantConfig as BitNetQuantConfig
  894. from .utils.quantization_config import BitsAndBytesConfig as BitsAndBytesConfig
  895. from .utils.quantization_config import CompressedTensorsConfig as CompressedTensorsConfig
  896. from .utils.quantization_config import EetqConfig as EetqConfig
  897. from .utils.quantization_config import FbgemmFp8Config as FbgemmFp8Config
  898. from .utils.quantization_config import FineGrainedFP8Config as FineGrainedFP8Config
  899. from .utils.quantization_config import FPQuantConfig as FPQuantConfig
  900. from .utils.quantization_config import GPTQConfig as GPTQConfig
  901. from .utils.quantization_config import HiggsConfig as HiggsConfig
  902. from .utils.quantization_config import HqqConfig as HqqConfig
  903. from .utils.quantization_config import QuantoConfig as QuantoConfig
  904. from .utils.quantization_config import QuarkConfig as QuarkConfig
  905. from .utils.quantization_config import SpQRConfig as SpQRConfig
  906. from .utils.quantization_config import TorchAoConfig as TorchAoConfig
  907. from .utils.quantization_config import VptqConfig as VptqConfig
  908. from .video_processing_utils import BaseVideoProcessor as BaseVideoProcessor
  909. else:
  910. import sys
  911. _import_structure = {k: set(v) for k, v in _import_structure.items()}
  912. import_structure = define_import_structure(Path(__file__).parent / "models", prefix="models")
  913. import_structure[frozenset({})].update(_import_structure)
  914. sys.modules[__name__] = _LazyModule(
  915. __name__,
  916. globals()["__file__"],
  917. import_structure,
  918. module_spec=__spec__,
  919. extra_objects={"__version__": __version__},
  920. )
  921. if not is_tf_available() and not is_torch_available() and not is_flax_available():
  922. logger.warning_advice(
  923. "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. "
  924. "Models won't be available and only tokenizers, configuration "
  925. "and file/data utilities can be used."
  926. )