modeling_auto.py 96 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382
  1. # coding=utf-8
  2. # Copyright 2018 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Auto Model class."""
  16. import os
  17. import warnings
  18. from collections import OrderedDict
  19. from typing import TYPE_CHECKING, Union
  20. from ...utils import logging
  21. from .auto_factory import (
  22. _BaseAutoBackboneClass,
  23. _BaseAutoModelClass,
  24. _LazyAutoMapping,
  25. auto_class_update,
  26. )
  27. from .configuration_auto import CONFIG_MAPPING_NAMES
  28. if TYPE_CHECKING:
  29. from ...generation import GenerationMixin
  30. from ...modeling_utils import PreTrainedModel
  31. # class for better type annotations
  32. class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
  33. pass
  34. logger = logging.get_logger(__name__)
  35. MODEL_MAPPING_NAMES = OrderedDict(
  36. [
  37. # Base model mapping
  38. ("aimv2", "Aimv2Model"),
  39. ("aimv2_vision_model", "Aimv2VisionModel"),
  40. ("albert", "AlbertModel"),
  41. ("align", "AlignModel"),
  42. ("altclip", "AltCLIPModel"),
  43. ("apertus", "ApertusModel"),
  44. ("arcee", "ArceeModel"),
  45. ("aria", "AriaModel"),
  46. ("aria_text", "AriaTextModel"),
  47. ("audio-spectrogram-transformer", "ASTModel"),
  48. ("autoformer", "AutoformerModel"),
  49. ("aya_vision", "AyaVisionModel"),
  50. ("bamba", "BambaModel"),
  51. ("bark", "BarkModel"),
  52. ("bart", "BartModel"),
  53. ("beit", "BeitModel"),
  54. ("bert", "BertModel"),
  55. ("bert-generation", "BertGenerationEncoder"),
  56. ("big_bird", "BigBirdModel"),
  57. ("bigbird_pegasus", "BigBirdPegasusModel"),
  58. ("biogpt", "BioGptModel"),
  59. ("bit", "BitModel"),
  60. ("bitnet", "BitNetModel"),
  61. ("blenderbot", "BlenderbotModel"),
  62. ("blenderbot-small", "BlenderbotSmallModel"),
  63. ("blip", "BlipModel"),
  64. ("blip-2", "Blip2Model"),
  65. ("blip_2_qformer", "Blip2QFormerModel"),
  66. ("bloom", "BloomModel"),
  67. ("blt", "BltModel"),
  68. ("bridgetower", "BridgeTowerModel"),
  69. ("bros", "BrosModel"),
  70. ("camembert", "CamembertModel"),
  71. ("canine", "CanineModel"),
  72. ("chameleon", "ChameleonModel"),
  73. ("chinese_clip", "ChineseCLIPModel"),
  74. ("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
  75. ("clap", "ClapModel"),
  76. ("clip", "CLIPModel"),
  77. ("clip_text_model", "CLIPTextModel"),
  78. ("clip_vision_model", "CLIPVisionModel"),
  79. ("clipseg", "CLIPSegModel"),
  80. ("clvp", "ClvpModelForConditionalGeneration"),
  81. ("code_llama", "LlamaModel"),
  82. ("codegen", "CodeGenModel"),
  83. ("cohere", "CohereModel"),
  84. ("cohere2", "Cohere2Model"),
  85. ("cohere2_vision", "Cohere2VisionModel"),
  86. ("conditional_detr", "ConditionalDetrModel"),
  87. ("convbert", "ConvBertModel"),
  88. ("convnext", "ConvNextModel"),
  89. ("convnextv2", "ConvNextV2Model"),
  90. ("cpmant", "CpmAntModel"),
  91. ("csm", "CsmForConditionalGeneration"),
  92. ("ctrl", "CTRLModel"),
  93. ("cvt", "CvtModel"),
  94. ("d_fine", "DFineModel"),
  95. ("dab-detr", "DabDetrModel"),
  96. ("dac", "DacModel"),
  97. ("data2vec-audio", "Data2VecAudioModel"),
  98. ("data2vec-text", "Data2VecTextModel"),
  99. ("data2vec-vision", "Data2VecVisionModel"),
  100. ("dbrx", "DbrxModel"),
  101. ("deberta", "DebertaModel"),
  102. ("deberta-v2", "DebertaV2Model"),
  103. ("decision_transformer", "DecisionTransformerModel"),
  104. ("deepseek_v2", "DeepseekV2Model"),
  105. ("deepseek_v3", "DeepseekV3Model"),
  106. ("deepseek_vl", "DeepseekVLModel"),
  107. ("deepseek_vl_hybrid", "DeepseekVLHybridModel"),
  108. ("deformable_detr", "DeformableDetrModel"),
  109. ("deit", "DeiTModel"),
  110. ("depth_pro", "DepthProModel"),
  111. ("deta", "DetaModel"),
  112. ("detr", "DetrModel"),
  113. ("dia", "DiaModel"),
  114. ("diffllama", "DiffLlamaModel"),
  115. ("dinat", "DinatModel"),
  116. ("dinov2", "Dinov2Model"),
  117. ("dinov2_with_registers", "Dinov2WithRegistersModel"),
  118. ("dinov3_convnext", "DINOv3ConvNextModel"),
  119. ("dinov3_vit", "DINOv3ViTModel"),
  120. ("distilbert", "DistilBertModel"),
  121. ("doge", "DogeModel"),
  122. ("donut-swin", "DonutSwinModel"),
  123. ("dots1", "Dots1Model"),
  124. ("dpr", "DPRQuestionEncoder"),
  125. ("dpt", "DPTModel"),
  126. ("edgetam", "EdgeTamModel"),
  127. ("edgetam_video", "EdgeTamVideoModel"),
  128. ("edgetam_vision_model", "EdgeTamVisionModel"),
  129. ("efficientformer", "EfficientFormerModel"),
  130. ("efficientloftr", "EfficientLoFTRModel"),
  131. ("efficientnet", "EfficientNetModel"),
  132. ("electra", "ElectraModel"),
  133. ("emu3", "Emu3Model"),
  134. ("encodec", "EncodecModel"),
  135. ("ernie", "ErnieModel"),
  136. ("ernie4_5", "Ernie4_5Model"),
  137. ("ernie4_5_moe", "Ernie4_5_MoeModel"),
  138. ("ernie_m", "ErnieMModel"),
  139. ("esm", "EsmModel"),
  140. ("evolla", "EvollaModel"),
  141. ("exaone4", "Exaone4Model"),
  142. ("falcon", "FalconModel"),
  143. ("falcon_h1", "FalconH1Model"),
  144. ("falcon_mamba", "FalconMambaModel"),
  145. ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
  146. ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
  147. ("flaubert", "FlaubertModel"),
  148. ("flava", "FlavaModel"),
  149. ("flex_olmo", "FlexOlmoModel"),
  150. ("florence2", "Florence2Model"),
  151. ("fnet", "FNetModel"),
  152. ("focalnet", "FocalNetModel"),
  153. ("fsmt", "FSMTModel"),
  154. ("funnel", ("FunnelModel", "FunnelBaseModel")),
  155. ("fuyu", "FuyuModel"),
  156. ("gemma", "GemmaModel"),
  157. ("gemma2", "Gemma2Model"),
  158. ("gemma3", "Gemma3Model"),
  159. ("gemma3_text", "Gemma3TextModel"),
  160. ("gemma3n", "Gemma3nModel"),
  161. ("gemma3n_audio", "Gemma3nAudioEncoder"),
  162. ("gemma3n_text", "Gemma3nTextModel"),
  163. ("gemma3n_vision", "TimmWrapperModel"),
  164. ("git", "GitModel"),
  165. ("glm", "GlmModel"),
  166. ("glm4", "Glm4Model"),
  167. ("glm4_moe", "Glm4MoeModel"),
  168. ("glm4v", "Glm4vModel"),
  169. ("glm4v_moe", "Glm4vMoeModel"),
  170. ("glm4v_moe_text", "Glm4vMoeTextModel"),
  171. ("glm4v_text", "Glm4vTextModel"),
  172. ("glpn", "GLPNModel"),
  173. ("got_ocr2", "GotOcr2Model"),
  174. ("gpt-sw3", "GPT2Model"),
  175. ("gpt2", "GPT2Model"),
  176. ("gpt_bigcode", "GPTBigCodeModel"),
  177. ("gpt_neo", "GPTNeoModel"),
  178. ("gpt_neox", "GPTNeoXModel"),
  179. ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
  180. ("gpt_oss", "GptOssModel"),
  181. ("gptj", "GPTJModel"),
  182. ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
  183. ("granite", "GraniteModel"),
  184. ("granitemoe", "GraniteMoeModel"),
  185. ("granitemoehybrid", "GraniteMoeHybridModel"),
  186. ("granitemoeshared", "GraniteMoeSharedModel"),
  187. ("graphormer", "GraphormerModel"),
  188. ("grounding-dino", "GroundingDinoModel"),
  189. ("groupvit", "GroupViTModel"),
  190. ("helium", "HeliumModel"),
  191. ("hgnet_v2", "HGNetV2Backbone"),
  192. ("hiera", "HieraModel"),
  193. ("hubert", "HubertModel"),
  194. ("hunyuan_v1_dense", "HunYuanDenseV1Model"),
  195. ("hunyuan_v1_moe", "HunYuanMoEV1Model"),
  196. ("ibert", "IBertModel"),
  197. ("idefics", "IdeficsModel"),
  198. ("idefics2", "Idefics2Model"),
  199. ("idefics3", "Idefics3Model"),
  200. ("idefics3_vision", "Idefics3VisionTransformer"),
  201. ("ijepa", "IJepaModel"),
  202. ("imagegpt", "ImageGPTModel"),
  203. ("informer", "InformerModel"),
  204. ("instructblip", "InstructBlipModel"),
  205. ("instructblipvideo", "InstructBlipVideoModel"),
  206. ("internvl", "InternVLModel"),
  207. ("internvl_vision", "InternVLVisionModel"),
  208. ("jamba", "JambaModel"),
  209. ("janus", "JanusModel"),
  210. ("jetmoe", "JetMoeModel"),
  211. ("jukebox", "JukeboxModel"),
  212. ("kosmos-2", "Kosmos2Model"),
  213. ("kosmos-2.5", "Kosmos2_5Model"),
  214. ("kyutai_speech_to_text", "KyutaiSpeechToTextModel"),
  215. ("layoutlm", "LayoutLMModel"),
  216. ("layoutlmv2", "LayoutLMv2Model"),
  217. ("layoutlmv3", "LayoutLMv3Model"),
  218. ("led", "LEDModel"),
  219. ("levit", "LevitModel"),
  220. ("lfm2", "Lfm2Model"),
  221. ("lfm2_vl", "Lfm2VlModel"),
  222. ("lightglue", "LightGlueForKeypointMatching"),
  223. ("lilt", "LiltModel"),
  224. ("llama", "LlamaModel"),
  225. ("llama4", "Llama4ForConditionalGeneration"),
  226. ("llama4_text", "Llama4TextModel"),
  227. ("llava", "LlavaModel"),
  228. ("llava_next", "LlavaNextModel"),
  229. ("llava_next_video", "LlavaNextVideoModel"),
  230. ("llava_onevision", "LlavaOnevisionModel"),
  231. ("longcat_flash", "LongcatFlashModel"),
  232. ("longformer", "LongformerModel"),
  233. ("longt5", "LongT5Model"),
  234. ("luke", "LukeModel"),
  235. ("lxmert", "LxmertModel"),
  236. ("m2m_100", "M2M100Model"),
  237. ("mamba", "MambaModel"),
  238. ("mamba2", "Mamba2Model"),
  239. ("marian", "MarianModel"),
  240. ("markuplm", "MarkupLMModel"),
  241. ("mask2former", "Mask2FormerModel"),
  242. ("maskformer", "MaskFormerModel"),
  243. ("maskformer-swin", "MaskFormerSwinModel"),
  244. ("mbart", "MBartModel"),
  245. ("mctct", "MCTCTModel"),
  246. ("mega", "MegaModel"),
  247. ("megatron-bert", "MegatronBertModel"),
  248. ("metaclip_2", "MetaClip2Model"),
  249. ("mgp-str", "MgpstrForSceneTextRecognition"),
  250. ("mimi", "MimiModel"),
  251. ("minimax", "MiniMaxModel"),
  252. ("ministral", "MinistralModel"),
  253. ("mistral", "MistralModel"),
  254. ("mistral3", "Mistral3Model"),
  255. ("mixtral", "MixtralModel"),
  256. ("mlcd", "MLCDVisionModel"),
  257. ("mllama", "MllamaModel"),
  258. ("mm-grounding-dino", "MMGroundingDinoModel"),
  259. ("mobilebert", "MobileBertModel"),
  260. ("mobilenet_v1", "MobileNetV1Model"),
  261. ("mobilenet_v2", "MobileNetV2Model"),
  262. ("mobilevit", "MobileViTModel"),
  263. ("mobilevitv2", "MobileViTV2Model"),
  264. ("modernbert", "ModernBertModel"),
  265. ("modernbert-decoder", "ModernBertDecoderModel"),
  266. ("moonshine", "MoonshineModel"),
  267. ("moshi", "MoshiModel"),
  268. ("mpnet", "MPNetModel"),
  269. ("mpt", "MptModel"),
  270. ("mra", "MraModel"),
  271. ("mt5", "MT5Model"),
  272. ("musicgen", "MusicgenModel"),
  273. ("musicgen_melody", "MusicgenMelodyModel"),
  274. ("mvp", "MvpModel"),
  275. ("nat", "NatModel"),
  276. ("nemotron", "NemotronModel"),
  277. ("nezha", "NezhaModel"),
  278. ("nllb-moe", "NllbMoeModel"),
  279. ("nystromformer", "NystromformerModel"),
  280. ("olmo", "OlmoModel"),
  281. ("olmo2", "Olmo2Model"),
  282. ("olmo3", "Olmo3Model"),
  283. ("olmoe", "OlmoeModel"),
  284. ("omdet-turbo", "OmDetTurboForObjectDetection"),
  285. ("oneformer", "OneFormerModel"),
  286. ("open-llama", "OpenLlamaModel"),
  287. ("openai-gpt", "OpenAIGPTModel"),
  288. ("opt", "OPTModel"),
  289. ("ovis2", "Ovis2Model"),
  290. ("owlv2", "Owlv2Model"),
  291. ("owlvit", "OwlViTModel"),
  292. ("paligemma", "PaliGemmaModel"),
  293. ("parakeet_ctc", "ParakeetForCTC"),
  294. ("parakeet_encoder", "ParakeetEncoder"),
  295. ("patchtsmixer", "PatchTSMixerModel"),
  296. ("patchtst", "PatchTSTModel"),
  297. ("pegasus", "PegasusModel"),
  298. ("pegasus_x", "PegasusXModel"),
  299. ("perceiver", "PerceiverModel"),
  300. ("perception_encoder", "PerceptionEncoder"),
  301. ("perception_lm", "PerceptionLMModel"),
  302. ("persimmon", "PersimmonModel"),
  303. ("phi", "PhiModel"),
  304. ("phi3", "Phi3Model"),
  305. ("phi4_multimodal", "Phi4MultimodalModel"),
  306. ("phimoe", "PhimoeModel"),
  307. ("pixtral", "PixtralVisionModel"),
  308. ("plbart", "PLBartModel"),
  309. ("poolformer", "PoolFormerModel"),
  310. ("prophetnet", "ProphetNetModel"),
  311. ("pvt", "PvtModel"),
  312. ("pvt_v2", "PvtV2Model"),
  313. ("qdqbert", "QDQBertModel"),
  314. ("qwen2", "Qwen2Model"),
  315. ("qwen2_5_vl", "Qwen2_5_VLModel"),
  316. ("qwen2_5_vl_text", "Qwen2_5_VLTextModel"),
  317. ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
  318. ("qwen2_moe", "Qwen2MoeModel"),
  319. ("qwen2_vl", "Qwen2VLModel"),
  320. ("qwen2_vl_text", "Qwen2VLTextModel"),
  321. ("qwen3", "Qwen3Model"),
  322. ("qwen3_moe", "Qwen3MoeModel"),
  323. ("qwen3_next", "Qwen3NextModel"),
  324. ("qwen3_vl", "Qwen3VLModel"),
  325. ("qwen3_vl_moe", "Qwen3VLMoeModel"),
  326. ("qwen3_vl_moe_text", "Qwen3VLMoeTextModel"),
  327. ("qwen3_vl_text", "Qwen3VLTextModel"),
  328. ("recurrent_gemma", "RecurrentGemmaModel"),
  329. ("reformer", "ReformerModel"),
  330. ("regnet", "RegNetModel"),
  331. ("rembert", "RemBertModel"),
  332. ("resnet", "ResNetModel"),
  333. ("retribert", "RetriBertModel"),
  334. ("roberta", "RobertaModel"),
  335. ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
  336. ("roc_bert", "RoCBertModel"),
  337. ("roformer", "RoFormerModel"),
  338. ("rt_detr", "RTDetrModel"),
  339. ("rt_detr_v2", "RTDetrV2Model"),
  340. ("rwkv", "RwkvModel"),
  341. ("sam", "SamModel"),
  342. ("sam2", "Sam2Model"),
  343. ("sam2_hiera_det_model", "Sam2HieraDetModel"),
  344. ("sam2_video", "Sam2VideoModel"),
  345. ("sam2_vision_model", "Sam2VisionModel"),
  346. ("sam_hq", "SamHQModel"),
  347. ("sam_hq_vision_model", "SamHQVisionModel"),
  348. ("sam_vision_model", "SamVisionModel"),
  349. ("seamless_m4t", "SeamlessM4TModel"),
  350. ("seamless_m4t_v2", "SeamlessM4Tv2Model"),
  351. ("seed_oss", "SeedOssModel"),
  352. ("segformer", "SegformerModel"),
  353. ("seggpt", "SegGptModel"),
  354. ("sew", "SEWModel"),
  355. ("sew-d", "SEWDModel"),
  356. ("siglip", "SiglipModel"),
  357. ("siglip2", "Siglip2Model"),
  358. ("siglip2_vision_model", "Siglip2VisionModel"),
  359. ("siglip_vision_model", "SiglipVisionModel"),
  360. ("smollm3", "SmolLM3Model"),
  361. ("smolvlm", "SmolVLMModel"),
  362. ("smolvlm_vision", "SmolVLMVisionTransformer"),
  363. ("speech_to_text", "Speech2TextModel"),
  364. ("speecht5", "SpeechT5Model"),
  365. ("splinter", "SplinterModel"),
  366. ("squeezebert", "SqueezeBertModel"),
  367. ("stablelm", "StableLmModel"),
  368. ("starcoder2", "Starcoder2Model"),
  369. ("swiftformer", "SwiftFormerModel"),
  370. ("swin", "SwinModel"),
  371. ("swin2sr", "Swin2SRModel"),
  372. ("swinv2", "Swinv2Model"),
  373. ("switch_transformers", "SwitchTransformersModel"),
  374. ("t5", "T5Model"),
  375. ("t5gemma", "T5GemmaModel"),
  376. ("table-transformer", "TableTransformerModel"),
  377. ("tapas", "TapasModel"),
  378. ("textnet", "TextNetModel"),
  379. ("time_series_transformer", "TimeSeriesTransformerModel"),
  380. ("timesfm", "TimesFmModel"),
  381. ("timesformer", "TimesformerModel"),
  382. ("timm_backbone", "TimmBackbone"),
  383. ("timm_wrapper", "TimmWrapperModel"),
  384. ("trajectory_transformer", "TrajectoryTransformerModel"),
  385. ("transfo-xl", "TransfoXLModel"),
  386. ("tvlt", "TvltModel"),
  387. ("tvp", "TvpModel"),
  388. ("udop", "UdopModel"),
  389. ("umt5", "UMT5Model"),
  390. ("unispeech", "UniSpeechModel"),
  391. ("unispeech-sat", "UniSpeechSatModel"),
  392. ("univnet", "UnivNetModel"),
  393. ("van", "VanModel"),
  394. ("vaultgemma", "VaultGemmaModel"),
  395. ("video_llava", "VideoLlavaModel"),
  396. ("videomae", "VideoMAEModel"),
  397. ("vilt", "ViltModel"),
  398. ("vipllava", "VipLlavaModel"),
  399. ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
  400. ("visual_bert", "VisualBertModel"),
  401. ("vit", "ViTModel"),
  402. ("vit_hybrid", "ViTHybridModel"),
  403. ("vit_mae", "ViTMAEModel"),
  404. ("vit_msn", "ViTMSNModel"),
  405. ("vitdet", "VitDetModel"),
  406. ("vits", "VitsModel"),
  407. ("vivit", "VivitModel"),
  408. ("vjepa2", "VJEPA2Model"),
  409. ("voxtral", "VoxtralForConditionalGeneration"),
  410. ("voxtral_encoder", "VoxtralEncoder"),
  411. ("wav2vec2", "Wav2Vec2Model"),
  412. ("wav2vec2-bert", "Wav2Vec2BertModel"),
  413. ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
  414. ("wavlm", "WavLMModel"),
  415. ("whisper", "WhisperModel"),
  416. ("xclip", "XCLIPModel"),
  417. ("xcodec", "XcodecModel"),
  418. ("xglm", "XGLMModel"),
  419. ("xlm", "XLMModel"),
  420. ("xlm-prophetnet", "XLMProphetNetModel"),
  421. ("xlm-roberta", "XLMRobertaModel"),
  422. ("xlm-roberta-xl", "XLMRobertaXLModel"),
  423. ("xlnet", "XLNetModel"),
  424. ("xlstm", "xLSTMModel"),
  425. ("xmod", "XmodModel"),
  426. ("yolos", "YolosModel"),
  427. ("yoso", "YosoModel"),
  428. ("zamba", "ZambaModel"),
  429. ("zamba2", "Zamba2Model"),
  430. ]
  431. )
  432. MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
  433. [
  434. # Model for pre-training mapping
  435. ("albert", "AlbertForPreTraining"),
  436. ("bart", "BartForConditionalGeneration"),
  437. ("bert", "BertForPreTraining"),
  438. ("big_bird", "BigBirdForPreTraining"),
  439. ("bloom", "BloomForCausalLM"),
  440. ("camembert", "CamembertForMaskedLM"),
  441. ("colpali", "ColPaliForRetrieval"),
  442. ("colqwen2", "ColQwen2ForRetrieval"),
  443. ("ctrl", "CTRLLMHeadModel"),
  444. ("data2vec-text", "Data2VecTextForMaskedLM"),
  445. ("deberta", "DebertaForMaskedLM"),
  446. ("deberta-v2", "DebertaV2ForMaskedLM"),
  447. ("distilbert", "DistilBertForMaskedLM"),
  448. ("electra", "ElectraForPreTraining"),
  449. ("ernie", "ErnieForPreTraining"),
  450. ("evolla", "EvollaForProteinText2Text"),
  451. ("exaone4", "Exaone4ForCausalLM"),
  452. ("falcon_mamba", "FalconMambaForCausalLM"),
  453. ("flaubert", "FlaubertWithLMHeadModel"),
  454. ("flava", "FlavaForPreTraining"),
  455. ("florence2", "Florence2ForConditionalGeneration"),
  456. ("fnet", "FNetForPreTraining"),
  457. ("fsmt", "FSMTForConditionalGeneration"),
  458. ("funnel", "FunnelForPreTraining"),
  459. ("gemma3", "Gemma3ForConditionalGeneration"),
  460. ("gpt-sw3", "GPT2LMHeadModel"),
  461. ("gpt2", "GPT2LMHeadModel"),
  462. ("gpt_bigcode", "GPTBigCodeForCausalLM"),
  463. ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
  464. ("hiera", "HieraForPreTraining"),
  465. ("ibert", "IBertForMaskedLM"),
  466. ("idefics", "IdeficsForVisionText2Text"),
  467. ("idefics2", "Idefics2ForConditionalGeneration"),
  468. ("idefics3", "Idefics3ForConditionalGeneration"),
  469. ("janus", "JanusForConditionalGeneration"),
  470. ("layoutlm", "LayoutLMForMaskedLM"),
  471. ("llava", "LlavaForConditionalGeneration"),
  472. ("llava_next", "LlavaNextForConditionalGeneration"),
  473. ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
  474. ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
  475. ("longformer", "LongformerForMaskedLM"),
  476. ("luke", "LukeForMaskedLM"),
  477. ("lxmert", "LxmertForPreTraining"),
  478. ("mamba", "MambaForCausalLM"),
  479. ("mamba2", "Mamba2ForCausalLM"),
  480. ("mega", "MegaForMaskedLM"),
  481. ("megatron-bert", "MegatronBertForPreTraining"),
  482. ("mistral3", "Mistral3ForConditionalGeneration"),
  483. ("mllama", "MllamaForConditionalGeneration"),
  484. ("mobilebert", "MobileBertForPreTraining"),
  485. ("mpnet", "MPNetForMaskedLM"),
  486. ("mpt", "MptForCausalLM"),
  487. ("mra", "MraForMaskedLM"),
  488. ("mvp", "MvpForConditionalGeneration"),
  489. ("nezha", "NezhaForPreTraining"),
  490. ("nllb-moe", "NllbMoeForConditionalGeneration"),
  491. ("openai-gpt", "OpenAIGPTLMHeadModel"),
  492. ("paligemma", "PaliGemmaForConditionalGeneration"),
  493. ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
  494. ("retribert", "RetriBertModel"),
  495. ("roberta", "RobertaForMaskedLM"),
  496. ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
  497. ("roc_bert", "RoCBertForPreTraining"),
  498. ("rwkv", "RwkvForCausalLM"),
  499. ("splinter", "SplinterForPreTraining"),
  500. ("squeezebert", "SqueezeBertForMaskedLM"),
  501. ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
  502. ("t5", "T5ForConditionalGeneration"),
  503. ("t5gemma", "T5GemmaForConditionalGeneration"),
  504. ("tapas", "TapasForMaskedLM"),
  505. ("transfo-xl", "TransfoXLLMHeadModel"),
  506. ("tvlt", "TvltForPreTraining"),
  507. ("unispeech", "UniSpeechForPreTraining"),
  508. ("unispeech-sat", "UniSpeechSatForPreTraining"),
  509. ("video_llava", "VideoLlavaForConditionalGeneration"),
  510. ("videomae", "VideoMAEForPreTraining"),
  511. ("vipllava", "VipLlavaForConditionalGeneration"),
  512. ("visual_bert", "VisualBertForPreTraining"),
  513. ("vit_mae", "ViTMAEForPreTraining"),
  514. ("voxtral", "VoxtralForConditionalGeneration"),
  515. ("wav2vec2", "Wav2Vec2ForPreTraining"),
  516. ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"),
  517. ("xlm", "XLMWithLMHeadModel"),
  518. ("xlm-roberta", "XLMRobertaForMaskedLM"),
  519. ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
  520. ("xlnet", "XLNetLMHeadModel"),
  521. ("xlstm", "xLSTMForCausalLM"),
  522. ("xmod", "XmodForMaskedLM"),
  523. ]
  524. )
  525. MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
  526. [
  527. # Model with LM heads mapping
  528. ("albert", "AlbertForMaskedLM"),
  529. ("bart", "BartForConditionalGeneration"),
  530. ("bert", "BertForMaskedLM"),
  531. ("big_bird", "BigBirdForMaskedLM"),
  532. ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
  533. ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
  534. ("bloom", "BloomForCausalLM"),
  535. ("camembert", "CamembertForMaskedLM"),
  536. ("codegen", "CodeGenForCausalLM"),
  537. ("convbert", "ConvBertForMaskedLM"),
  538. ("cpmant", "CpmAntForCausalLM"),
  539. ("ctrl", "CTRLLMHeadModel"),
  540. ("data2vec-text", "Data2VecTextForMaskedLM"),
  541. ("deberta", "DebertaForMaskedLM"),
  542. ("deberta-v2", "DebertaV2ForMaskedLM"),
  543. ("dia", "DiaForConditionalGeneration"),
  544. ("distilbert", "DistilBertForMaskedLM"),
  545. ("electra", "ElectraForMaskedLM"),
  546. ("encoder-decoder", "EncoderDecoderModel"),
  547. ("ernie", "ErnieForMaskedLM"),
  548. ("esm", "EsmForMaskedLM"),
  549. ("exaone4", "Exaone4ForCausalLM"),
  550. ("falcon_mamba", "FalconMambaForCausalLM"),
  551. ("flaubert", "FlaubertWithLMHeadModel"),
  552. ("fnet", "FNetForMaskedLM"),
  553. ("fsmt", "FSMTForConditionalGeneration"),
  554. ("funnel", "FunnelForMaskedLM"),
  555. ("git", "GitForCausalLM"),
  556. ("gpt-sw3", "GPT2LMHeadModel"),
  557. ("gpt2", "GPT2LMHeadModel"),
  558. ("gpt_bigcode", "GPTBigCodeForCausalLM"),
  559. ("gpt_neo", "GPTNeoForCausalLM"),
  560. ("gpt_neox", "GPTNeoXForCausalLM"),
  561. ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
  562. ("gptj", "GPTJForCausalLM"),
  563. ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
  564. ("ibert", "IBertForMaskedLM"),
  565. ("layoutlm", "LayoutLMForMaskedLM"),
  566. ("led", "LEDForConditionalGeneration"),
  567. ("longformer", "LongformerForMaskedLM"),
  568. ("longt5", "LongT5ForConditionalGeneration"),
  569. ("luke", "LukeForMaskedLM"),
  570. ("m2m_100", "M2M100ForConditionalGeneration"),
  571. ("mamba", "MambaForCausalLM"),
  572. ("mamba2", "Mamba2ForCausalLM"),
  573. ("marian", "MarianMTModel"),
  574. ("mega", "MegaForMaskedLM"),
  575. ("megatron-bert", "MegatronBertForCausalLM"),
  576. ("mobilebert", "MobileBertForMaskedLM"),
  577. ("moonshine", "MoonshineForConditionalGeneration"),
  578. ("mpnet", "MPNetForMaskedLM"),
  579. ("mpt", "MptForCausalLM"),
  580. ("mra", "MraForMaskedLM"),
  581. ("mvp", "MvpForConditionalGeneration"),
  582. ("nezha", "NezhaForMaskedLM"),
  583. ("nllb-moe", "NllbMoeForConditionalGeneration"),
  584. ("nystromformer", "NystromformerForMaskedLM"),
  585. ("openai-gpt", "OpenAIGPTLMHeadModel"),
  586. ("pegasus_x", "PegasusXForConditionalGeneration"),
  587. ("plbart", "PLBartForConditionalGeneration"),
  588. ("pop2piano", "Pop2PianoForConditionalGeneration"),
  589. ("qdqbert", "QDQBertForMaskedLM"),
  590. ("reformer", "ReformerModelWithLMHead"),
  591. ("rembert", "RemBertForMaskedLM"),
  592. ("roberta", "RobertaForMaskedLM"),
  593. ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
  594. ("roc_bert", "RoCBertForMaskedLM"),
  595. ("roformer", "RoFormerForMaskedLM"),
  596. ("rwkv", "RwkvForCausalLM"),
  597. ("speech_to_text", "Speech2TextForConditionalGeneration"),
  598. ("squeezebert", "SqueezeBertForMaskedLM"),
  599. ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
  600. ("t5", "T5ForConditionalGeneration"),
  601. ("t5gemma", "T5GemmaForConditionalGeneration"),
  602. ("tapas", "TapasForMaskedLM"),
  603. ("transfo-xl", "TransfoXLLMHeadModel"),
  604. ("wav2vec2", "Wav2Vec2ForMaskedLM"),
  605. ("whisper", "WhisperForConditionalGeneration"),
  606. ("xlm", "XLMWithLMHeadModel"),
  607. ("xlm-roberta", "XLMRobertaForMaskedLM"),
  608. ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
  609. ("xlnet", "XLNetLMHeadModel"),
  610. ("xmod", "XmodForMaskedLM"),
  611. ("yoso", "YosoForMaskedLM"),
  612. ]
  613. )
# Registry for causal language modeling. Each entry pairs a model-type key (first
# element) with the name of the class to use for it, stored as a string (second
# element); the OrderedDict keeps the insertion order of the list below.
# Notable entries:
#   - "code_llama" reuses LlamaForCausalLM and "gpt-sw3" reuses GPT2LMHeadModel.
#   - "llama4" and "llama4_text" both point at Llama4ForCausalLM.
#   - "gemma3", "gemma3n" and "got_ocr2" are registered with their
#     *ForConditionalGeneration classes rather than a *ForCausalLM class.
# NOTE(review): the class names are presumably resolved to real classes by the
# code that consumes this mapping -- confirm against its definition.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("apertus", "ApertusForCausalLM"),
        ("arcee", "ArceeForCausalLM"),
        ("aria_text", "AriaTextForCausalLM"),
        ("bamba", "BambaForCausalLM"),
        ("bart", "BartForCausalLM"),
        ("bert", "BertLMHeadModel"),
        ("bert-generation", "BertGenerationDecoder"),
        ("big_bird", "BigBirdForCausalLM"),
        ("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
        ("biogpt", "BioGptForCausalLM"),
        ("bitnet", "BitNetForCausalLM"),
        ("blenderbot", "BlenderbotForCausalLM"),
        ("blenderbot-small", "BlenderbotSmallForCausalLM"),
        ("bloom", "BloomForCausalLM"),
        ("blt", "BltForCausalLM"),
        ("camembert", "CamembertForCausalLM"),
        ("code_llama", "LlamaForCausalLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("cohere", "CohereForCausalLM"),
        ("cohere2", "Cohere2ForCausalLM"),
        ("cpmant", "CpmAntForCausalLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForCausalLM"),
        ("dbrx", "DbrxForCausalLM"),
        ("deepseek_v2", "DeepseekV2ForCausalLM"),
        ("deepseek_v3", "DeepseekV3ForCausalLM"),
        ("diffllama", "DiffLlamaForCausalLM"),
        ("doge", "DogeForCausalLM"),
        ("dots1", "Dots1ForCausalLM"),
        ("electra", "ElectraForCausalLM"),
        ("emu3", "Emu3ForCausalLM"),
        ("ernie", "ErnieForCausalLM"),
        ("ernie4_5", "Ernie4_5ForCausalLM"),
        ("ernie4_5_moe", "Ernie4_5_MoeForCausalLM"),
        ("exaone4", "Exaone4ForCausalLM"),
        ("falcon", "FalconForCausalLM"),
        ("falcon_h1", "FalconH1ForCausalLM"),
        ("falcon_mamba", "FalconMambaForCausalLM"),
        ("flex_olmo", "FlexOlmoForCausalLM"),
        ("fuyu", "FuyuForCausalLM"),
        ("gemma", "GemmaForCausalLM"),
        ("gemma2", "Gemma2ForCausalLM"),
        ("gemma3", "Gemma3ForConditionalGeneration"),
        ("gemma3_text", "Gemma3ForCausalLM"),
        ("gemma3n", "Gemma3nForConditionalGeneration"),
        ("gemma3n_text", "Gemma3nForCausalLM"),
        ("git", "GitForCausalLM"),
        ("glm", "GlmForCausalLM"),
        ("glm4", "Glm4ForCausalLM"),
        ("glm4_moe", "Glm4MoeForCausalLM"),
        ("got_ocr2", "GotOcr2ForConditionalGeneration"),
        ("gpt-sw3", "GPT2LMHeadModel"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_bigcode", "GPTBigCodeForCausalLM"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gpt_oss", "GptOssForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("granite", "GraniteForCausalLM"),
        ("granitemoe", "GraniteMoeForCausalLM"),
        ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),
        ("granitemoeshared", "GraniteMoeSharedForCausalLM"),
        ("helium", "HeliumForCausalLM"),
        ("hunyuan_v1_dense", "HunYuanDenseV1ForCausalLM"),
        ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
        ("jamba", "JambaForCausalLM"),
        ("jetmoe", "JetMoeForCausalLM"),
        ("lfm2", "Lfm2ForCausalLM"),
        ("llama", "LlamaForCausalLM"),
        ("llama4", "Llama4ForCausalLM"),
        ("llama4_text", "Llama4ForCausalLM"),
        ("longcat_flash", "LongcatFlashForCausalLM"),
        ("mamba", "MambaForCausalLM"),
        ("mamba2", "Mamba2ForCausalLM"),
        ("marian", "MarianForCausalLM"),
        ("mbart", "MBartForCausalLM"),
        ("mega", "MegaForCausalLM"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("minimax", "MiniMaxForCausalLM"),
        ("ministral", "MinistralForCausalLM"),
        ("mistral", "MistralForCausalLM"),
        ("mixtral", "MixtralForCausalLM"),
        ("mllama", "MllamaForCausalLM"),
        ("modernbert-decoder", "ModernBertDecoderForCausalLM"),
        ("moshi", "MoshiForCausalLM"),
        ("mpt", "MptForCausalLM"),
        ("musicgen", "MusicgenForCausalLM"),
        ("musicgen_melody", "MusicgenMelodyForCausalLM"),
        ("mvp", "MvpForCausalLM"),
        ("nemotron", "NemotronForCausalLM"),
        ("olmo", "OlmoForCausalLM"),
        ("olmo2", "Olmo2ForCausalLM"),
        ("olmo3", "Olmo3ForCausalLM"),
        ("olmoe", "OlmoeForCausalLM"),
        ("open-llama", "OpenLlamaForCausalLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("opt", "OPTForCausalLM"),
        ("pegasus", "PegasusForCausalLM"),
        ("persimmon", "PersimmonForCausalLM"),
        ("phi", "PhiForCausalLM"),
        ("phi3", "Phi3ForCausalLM"),
        ("phi4_multimodal", "Phi4MultimodalForCausalLM"),
        ("phimoe", "PhimoeForCausalLM"),
        ("plbart", "PLBartForCausalLM"),
        ("prophetnet", "ProphetNetForCausalLM"),
        ("qdqbert", "QDQBertLMHeadModel"),
        ("qwen2", "Qwen2ForCausalLM"),
        ("qwen2_moe", "Qwen2MoeForCausalLM"),
        ("qwen3", "Qwen3ForCausalLM"),
        ("qwen3_moe", "Qwen3MoeForCausalLM"),
        ("qwen3_next", "Qwen3NextForCausalLM"),
        ("recurrent_gemma", "RecurrentGemmaForCausalLM"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForCausalLM"),
        ("roberta", "RobertaForCausalLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForCausalLM"),
        ("roc_bert", "RoCBertForCausalLM"),
        ("roformer", "RoFormerForCausalLM"),
        ("rwkv", "RwkvForCausalLM"),
        ("seed_oss", "SeedOssForCausalLM"),
        ("smollm3", "SmolLM3ForCausalLM"),
        ("speech_to_text_2", "Speech2Text2ForCausalLM"),
        ("stablelm", "StableLmForCausalLM"),
        ("starcoder2", "Starcoder2ForCausalLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("trocr", "TrOCRForCausalLM"),
        ("vaultgemma", "VaultGemmaForCausalLM"),
        ("whisper", "WhisperForCausalLM"),
        ("xglm", "XGLMForCausalLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-prophetnet", "XLMProphetNetForCausalLM"),
        ("xlm-roberta", "XLMRobertaForCausalLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("xlstm", "xLSTMForCausalLM"),
        ("xmod", "XmodForCausalLM"),
        ("zamba", "ZambaForCausalLM"),
        ("zamba2", "Zamba2ForCausalLM"),
    ]
)
# Registry of image models. Each entry pairs a model-type key with the name of a
# model class, stored as a string. Most values are "*Model" classes;
# "timm_backbone" maps to "TimmBackbone", and several keys ("aimv2_vision_model",
# "siglip_vision_model", "llama4", "mlcd", "mllama") map to "*VisionModel" classes.
MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image mapping
        ("aimv2_vision_model", "Aimv2VisionModel"),
        ("beit", "BeitModel"),
        ("bit", "BitModel"),
        ("cohere2_vision", "Cohere2VisionModel"),
        ("conditional_detr", "ConditionalDetrModel"),
        ("convnext", "ConvNextModel"),
        ("convnextv2", "ConvNextV2Model"),
        ("dab-detr", "DabDetrModel"),
        ("data2vec-vision", "Data2VecVisionModel"),
        ("deformable_detr", "DeformableDetrModel"),
        ("deit", "DeiTModel"),
        ("depth_pro", "DepthProModel"),
        ("deta", "DetaModel"),
        ("detr", "DetrModel"),
        ("dinat", "DinatModel"),
        ("dinov2", "Dinov2Model"),
        ("dinov2_with_registers", "Dinov2WithRegistersModel"),
        ("dinov3_convnext", "DINOv3ConvNextModel"),
        ("dinov3_vit", "DINOv3ViTModel"),
        ("dpt", "DPTModel"),
        ("efficientformer", "EfficientFormerModel"),
        ("efficientnet", "EfficientNetModel"),
        ("focalnet", "FocalNetModel"),
        ("glpn", "GLPNModel"),
        ("hiera", "HieraModel"),
        ("ijepa", "IJepaModel"),
        ("imagegpt", "ImageGPTModel"),
        ("levit", "LevitModel"),
        ("llama4", "Llama4VisionModel"),
        ("mlcd", "MLCDVisionModel"),
        ("mllama", "MllamaVisionModel"),
        ("mobilenet_v1", "MobileNetV1Model"),
        ("mobilenet_v2", "MobileNetV2Model"),
        ("mobilevit", "MobileViTModel"),
        ("mobilevitv2", "MobileViTV2Model"),
        ("nat", "NatModel"),
        ("poolformer", "PoolFormerModel"),
        ("pvt", "PvtModel"),
        ("regnet", "RegNetModel"),
        ("resnet", "ResNetModel"),
        ("segformer", "SegformerModel"),
        ("siglip_vision_model", "SiglipVisionModel"),
        ("swiftformer", "SwiftFormerModel"),
        ("swin", "SwinModel"),
        ("swin2sr", "Swin2SRModel"),
        ("swinv2", "Swinv2Model"),
        ("table-transformer", "TableTransformerModel"),
        ("timesformer", "TimesformerModel"),
        ("timm_backbone", "TimmBackbone"),
        ("timm_wrapper", "TimmWrapperModel"),
        ("van", "VanModel"),
        ("videomae", "VideoMAEModel"),
        ("vit", "ViTModel"),
        ("vit_hybrid", "ViTHybridModel"),
        ("vit_mae", "ViTMAEModel"),
        ("vit_msn", "ViTMSNModel"),
        ("vitdet", "VitDetModel"),
        ("vivit", "VivitModel"),
        ("yolos", "YolosModel"),
    ]
)
# Registry for masked image modeling. Each entry pairs a model-type key with the
# name of its "*ForMaskedImageModeling" class, stored as a string.
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked Image Modeling mapping
        ("deit", "DeiTForMaskedImageModeling"),
        ("focalnet", "FocalNetForMaskedImageModeling"),
        ("swin", "SwinForMaskedImageModeling"),
        ("swinv2", "Swinv2ForMaskedImageModeling"),
        ("vit", "ViTForMaskedImageModeling"),
    ]
)
# Registry for causal image modeling. It holds a single entry: the "imagegpt"
# model-type key paired with the class name "ImageGPTForCausalImageModeling".
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    # Model for Causal Image Modeling mapping
    [
        ("imagegpt", "ImageGPTForCausalImageModeling"),
    ]
)
# Registry for image classification. Each entry pairs a model-type key with the
# class name(s) registered for it. The value is usually a single string, but for
# "deit", "efficientformer", "levit" and "perceiver" it is a tuple of several
# class names, so consumers must accept both a str and a tuple of str.
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image Classification mapping
        ("beit", "BeitForImageClassification"),
        ("bit", "BitForImageClassification"),
        ("clip", "CLIPForImageClassification"),
        ("convnext", "ConvNextForImageClassification"),
        ("convnextv2", "ConvNextV2ForImageClassification"),
        ("cvt", "CvtForImageClassification"),
        ("data2vec-vision", "Data2VecVisionForImageClassification"),
        (
            "deit",
            ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher"),
        ),
        ("dinat", "DinatForImageClassification"),
        ("dinov2", "Dinov2ForImageClassification"),
        ("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"),
        ("donut-swin", "DonutSwinForImageClassification"),
        (
            "efficientformer",
            (
                "EfficientFormerForImageClassification",
                "EfficientFormerForImageClassificationWithTeacher",
            ),
        ),
        ("efficientnet", "EfficientNetForImageClassification"),
        ("focalnet", "FocalNetForImageClassification"),
        ("hgnet_v2", "HGNetV2ForImageClassification"),
        ("hiera", "HieraForImageClassification"),
        ("ijepa", "IJepaForImageClassification"),
        ("imagegpt", "ImageGPTForImageClassification"),
        (
            "levit",
            ("LevitForImageClassification", "LevitForImageClassificationWithTeacher"),
        ),
        ("metaclip_2", "MetaClip2ForImageClassification"),
        ("mobilenet_v1", "MobileNetV1ForImageClassification"),
        ("mobilenet_v2", "MobileNetV2ForImageClassification"),
        ("mobilevit", "MobileViTForImageClassification"),
        ("mobilevitv2", "MobileViTV2ForImageClassification"),
        ("nat", "NatForImageClassification"),
        (
            "perceiver",
            (
                "PerceiverForImageClassificationLearned",
                "PerceiverForImageClassificationFourier",
                "PerceiverForImageClassificationConvProcessing",
            ),
        ),
        ("poolformer", "PoolFormerForImageClassification"),
        ("pvt", "PvtForImageClassification"),
        ("pvt_v2", "PvtV2ForImageClassification"),
        ("regnet", "RegNetForImageClassification"),
        ("resnet", "ResNetForImageClassification"),
        ("segformer", "SegformerForImageClassification"),
        ("shieldgemma2", "ShieldGemma2ForImageClassification"),
        ("siglip", "SiglipForImageClassification"),
        ("siglip2", "Siglip2ForImageClassification"),
        ("swiftformer", "SwiftFormerForImageClassification"),
        ("swin", "SwinForImageClassification"),
        ("swinv2", "Swinv2ForImageClassification"),
        ("textnet", "TextNetForImageClassification"),
        ("timm_wrapper", "TimmWrapperForImageClassification"),
        ("van", "VanForImageClassification"),
        ("vit", "ViTForImageClassification"),
        ("vit_hybrid", "ViTHybridForImageClassification"),
        ("vit_msn", "ViTMSNForImageClassification"),
    ]
)
# Registry for image segmentation. It holds a single entry: the "detr"
# model-type key paired with the class name "DetrForSegmentation".
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Do not add new models here, this class will be deprecated in the future.
        # Model for Image Segmentation mapping
        ("detr", "DetrForSegmentation"),
    ]
)
# Registry for semantic segmentation. Each entry pairs a model-type key with the
# name of its "*ForSemanticSegmentation" class, stored as a string.
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("beit", "BeitForSemanticSegmentation"),
        ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
        ("dpt", "DPTForSemanticSegmentation"),
        ("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"),
        ("mobilevit", "MobileViTForSemanticSegmentation"),
        ("mobilevitv2", "MobileViTV2ForSemanticSegmentation"),
        ("segformer", "SegformerForSemanticSegmentation"),
        ("upernet", "UperNetForSemanticSegmentation"),
    ]
)
# Registry for instance segmentation. It holds a single entry: the "maskformer"
# model-type key paired with the class name "MaskFormerForInstanceSegmentation".
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Instance Segmentation mapping
        # MaskFormerForInstanceSegmentation can be removed from this mapping in v5
        ("maskformer", "MaskFormerForInstanceSegmentation"),
    ]
)
# Registry for universal segmentation. Each entry pairs a model-type key with a
# class name stored as a string. The "detr" and "maskformer" entries use the same
# class names ("DetrForSegmentation", "MaskFormerForInstanceSegmentation") that
# are also registered in the image- and instance-segmentation mappings.
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Universal Segmentation mapping
        ("detr", "DetrForSegmentation"),
        ("eomt", "EomtForUniversalSegmentation"),
        ("mask2former", "Mask2FormerForUniversalSegmentation"),
        ("maskformer", "MaskFormerForInstanceSegmentation"),
        ("oneformer", "OneFormerForUniversalSegmentation"),
    ]
)
# Registry for video classification. Each entry pairs a model-type key with the
# name of its "*ForVideoClassification" class, stored as a string.
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Video Classification mapping
        ("timesformer", "TimesformerForVideoClassification"),
        ("videomae", "VideoMAEForVideoClassification"),
        ("vivit", "VivitForVideoClassification"),
        ("vjepa2", "VJEPA2ForVideoClassification"),
    ]
)
# Registry for vision-to-sequence models. Each entry pairs a model-type key with
# a class name stored as a string. All values are "*ForConditionalGeneration"
# classes except "git" (GitForCausalLM) and "vision-encoder-decoder"
# (VisionEncoderDecoderModel).
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for Vision 2 Seq mapping
        ("blip", "BlipForConditionalGeneration"),
        ("blip-2", "Blip2ForConditionalGeneration"),
        ("chameleon", "ChameleonForConditionalGeneration"),
        ("git", "GitForCausalLM"),
        ("idefics2", "Idefics2ForConditionalGeneration"),
        ("idefics3", "Idefics3ForConditionalGeneration"),
        ("instructblip", "InstructBlipForConditionalGeneration"),
        ("instructblipvideo", "InstructBlipVideoForConditionalGeneration"),
        ("kosmos-2", "Kosmos2ForConditionalGeneration"),
        ("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"),
        ("llava", "LlavaForConditionalGeneration"),
        ("llava_next", "LlavaNextForConditionalGeneration"),
        ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
        ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
        ("mistral3", "Mistral3ForConditionalGeneration"),
        ("mllama", "MllamaForConditionalGeneration"),
        ("ovis2", "Ovis2ForConditionalGeneration"),
        ("paligemma", "PaliGemmaForConditionalGeneration"),
        ("pix2struct", "Pix2StructForConditionalGeneration"),
        ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
        ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
        ("qwen3_vl", "Qwen3VLForConditionalGeneration"),
        ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"),
        ("video_llava", "VideoLlavaForConditionalGeneration"),
        ("vipllava", "VipLlavaForConditionalGeneration"),
        ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
    ]
)
# Registry for retrieval models. It holds a single entry: the "colpali"
# model-type key paired with the class name "ColPaliForRetrieval".
MODEL_FOR_RETRIEVAL_MAPPING_NAMES = OrderedDict(
    [
        # Model for Retrieval mapping
        ("colpali", "ColPaliForRetrieval"),
    ]
)
# Registry for image-text-to-text models. Each entry pairs a model-type key with
# a class name stored as a string.
# Two keys reuse another model's class: "pixtral" maps to
# LlavaForConditionalGeneration and "shieldgemma2" maps to
# Gemma3ForConditionalGeneration.
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image Text To Text mapping
        ("aria", "AriaForConditionalGeneration"),
        ("aya_vision", "AyaVisionForConditionalGeneration"),
        ("blip", "BlipForConditionalGeneration"),
        ("blip-2", "Blip2ForConditionalGeneration"),
        ("chameleon", "ChameleonForConditionalGeneration"),
        ("cohere2_vision", "Cohere2VisionForConditionalGeneration"),
        ("deepseek_vl", "DeepseekVLForConditionalGeneration"),
        ("deepseek_vl_hybrid", "DeepseekVLHybridForConditionalGeneration"),
        ("emu3", "Emu3ForConditionalGeneration"),
        ("evolla", "EvollaForProteinText2Text"),
        ("florence2", "Florence2ForConditionalGeneration"),
        ("fuyu", "FuyuForCausalLM"),
        ("gemma3", "Gemma3ForConditionalGeneration"),
        ("gemma3n", "Gemma3nForConditionalGeneration"),
        ("git", "GitForCausalLM"),
        ("glm4v", "Glm4vForConditionalGeneration"),
        ("glm4v_moe", "Glm4vMoeForConditionalGeneration"),
        ("got_ocr2", "GotOcr2ForConditionalGeneration"),
        ("idefics", "IdeficsForVisionText2Text"),
        ("idefics2", "Idefics2ForConditionalGeneration"),
        ("idefics3", "Idefics3ForConditionalGeneration"),
        ("instructblip", "InstructBlipForConditionalGeneration"),
        ("internvl", "InternVLForConditionalGeneration"),
        ("janus", "JanusForConditionalGeneration"),
        ("kosmos-2", "Kosmos2ForConditionalGeneration"),
        ("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"),
        ("lfm2_vl", "Lfm2VlForConditionalGeneration"),
        ("llama4", "Llama4ForConditionalGeneration"),
        ("llava", "LlavaForConditionalGeneration"),
        ("llava_next", "LlavaNextForConditionalGeneration"),
        ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
        ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
        ("mistral3", "Mistral3ForConditionalGeneration"),
        ("mllama", "MllamaForConditionalGeneration"),
        ("ovis2", "Ovis2ForConditionalGeneration"),
        ("paligemma", "PaliGemmaForConditionalGeneration"),
        ("perception_lm", "PerceptionLMForConditionalGeneration"),
        ("pix2struct", "Pix2StructForConditionalGeneration"),
        ("pixtral", "LlavaForConditionalGeneration"),
        ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
        ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
        ("qwen3_vl", "Qwen3VLForConditionalGeneration"),
        ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"),
        ("shieldgemma2", "Gemma3ForConditionalGeneration"),
        ("smolvlm", "SmolVLMForConditionalGeneration"),
        ("udop", "UdopForConditionalGeneration"),
        ("vipllava", "VipLlavaForConditionalGeneration"),
        ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
    ]
)
# Registry for masked language modeling. Each entry pairs a model-type key with a
# class name stored as a string. Most values are "*ForMaskedLM" classes; the
# exceptions are "bart", "mbart" and "mvp" (their "*ForConditionalGeneration"
# classes) and "flaubert" and "xlm" (their "*WithLMHeadModel" classes).
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("ernie", "ErnieForMaskedLM"),
        ("esm", "EsmForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("funnel", "FunnelForMaskedLM"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("mbart", "MBartForConditionalGeneration"),
        ("mega", "MegaForMaskedLM"),
        ("megatron-bert", "MegatronBertForMaskedLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("modernbert", "ModernBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mra", "MraForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("perceiver", "PerceiverForMaskedLM"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerForMaskedLM"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
        ("roc_bert", "RoCBertForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("tapas", "TapasForMaskedLM"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xmod", "XmodForMaskedLM"),
        ("yoso", "YosoForMaskedLM"),
    ]
)
# Registry for object detection. Each entry pairs a model-type key with the name
# of its "*ForObjectDetection" class, stored as a string.
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Object Detection mapping
        ("conditional_detr", "ConditionalDetrForObjectDetection"),
        ("d_fine", "DFineForObjectDetection"),
        ("dab-detr", "DabDetrForObjectDetection"),
        ("deformable_detr", "DeformableDetrForObjectDetection"),
        ("deta", "DetaForObjectDetection"),
        ("detr", "DetrForObjectDetection"),
        ("rt_detr", "RTDetrForObjectDetection"),
        ("rt_detr_v2", "RTDetrV2ForObjectDetection"),
        ("table-transformer", "TableTransformerForObjectDetection"),
        ("yolos", "YolosForObjectDetection"),
    ]
)
# Registry for zero-shot object detection. Each entry pairs a model-type key with
# a class name stored as a string; the class names themselves end in
# "ForObjectDetection" (there is no separate "ZeroShot" suffix).
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Object Detection mapping
        ("grounding-dino", "GroundingDinoForObjectDetection"),
        ("mm-grounding-dino", "MMGroundingDinoForObjectDetection"),
        ("omdet-turbo", "OmDetTurboForObjectDetection"),
        ("owlv2", "Owlv2ForObjectDetection"),
        ("owlvit", "OwlViTForObjectDetection"),
    ]
)
# Registry for depth estimation. Each entry pairs a model-type key with the name
# of its "*ForDepthEstimation" class, stored as a string.
MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for depth estimation mapping
        ("depth_anything", "DepthAnythingForDepthEstimation"),
        ("depth_pro", "DepthProForDepthEstimation"),
        ("dpt", "DPTForDepthEstimation"),
        ("glpn", "GLPNForDepthEstimation"),
        ("prompt_depth_anything", "PromptDepthAnythingForDepthEstimation"),
        ("zoedepth", "ZoeDepthForDepthEstimation"),
    ]
)
# Registry for sequence-to-sequence language modeling. Each entry pairs a
# model-type key with a class name stored as a string. Most values are
# "*ForConditionalGeneration" classes; the exceptions are "encoder-decoder"
# (EncoderDecoderModel), "marian" (MarianMTModel) and the two "seamless_m4t*"
# keys ("*ForTextToText" classes).
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "BartForConditionalGeneration"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot", "BlenderbotForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
        ("granite_speech", "GraniteSpeechForConditionalGeneration"),
        ("led", "LEDForConditionalGeneration"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("mbart", "MBartForConditionalGeneration"),
        ("mt5", "MT5ForConditionalGeneration"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nllb-moe", "NllbMoeForConditionalGeneration"),
        ("pegasus", "PegasusForConditionalGeneration"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("prophetnet", "ProphetNetForConditionalGeneration"),
        ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
        ("seamless_m4t", "SeamlessM4TForTextToText"),
        ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("t5gemma", "T5GemmaForConditionalGeneration"),
        ("umt5", "UMT5ForConditionalGeneration"),
        ("voxtral", "VoxtralForConditionalGeneration"),
        ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
    ]
)
# Registry for speech sequence-to-sequence models. Each entry pairs a model-type
# key with a class name stored as a string. The "seamless_m4t*" and "speecht5"
# keys map to "*ForSpeechToText" classes and "speech-encoder-decoder" maps to
# SpeechEncoderDecoderModel; the rest are "*ForConditionalGeneration" classes.
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for Speech Seq2Seq mapping
        ("dia", "DiaForConditionalGeneration"),
        ("granite_speech", "GraniteSpeechForConditionalGeneration"),
        ("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"),
        ("moonshine", "MoonshineForConditionalGeneration"),
        ("pop2piano", "Pop2PianoForConditionalGeneration"),
        ("seamless_m4t", "SeamlessM4TForSpeechToText"),
        ("seamless_m4t_v2", "SeamlessM4Tv2ForSpeechToText"),
        ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("speecht5", "SpeechT5ForSpeechToText"),
        ("whisper", "WhisperForConditionalGeneration"),
    ]
)
# Registry for sequence classification. Each entry pairs a model-type key with
# the name of its "*ForSequenceClassification" class, stored as a string.
# Two keys reuse another model's class: "code_llama" maps to
# LlamaForSequenceClassification and "gpt-sw3" maps to
# GPT2ForSequenceClassification.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "AlbertForSequenceClassification"),
        ("arcee", "ArceeForSequenceClassification"),
        ("bart", "BartForSequenceClassification"),
        ("bert", "BertForSequenceClassification"),
        ("big_bird", "BigBirdForSequenceClassification"),
        ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
        ("biogpt", "BioGptForSequenceClassification"),
        ("bloom", "BloomForSequenceClassification"),
        ("camembert", "CamembertForSequenceClassification"),
        ("canine", "CanineForSequenceClassification"),
        ("code_llama", "LlamaForSequenceClassification"),
        ("convbert", "ConvBertForSequenceClassification"),
        ("ctrl", "CTRLForSequenceClassification"),
        ("data2vec-text", "Data2VecTextForSequenceClassification"),
        ("deberta", "DebertaForSequenceClassification"),
        ("deberta-v2", "DebertaV2ForSequenceClassification"),
        ("deepseek_v2", "DeepseekV2ForSequenceClassification"),
        ("deepseek_v3", "DeepseekV3ForSequenceClassification"),
        ("diffllama", "DiffLlamaForSequenceClassification"),
        ("distilbert", "DistilBertForSequenceClassification"),
        ("doge", "DogeForSequenceClassification"),
        ("electra", "ElectraForSequenceClassification"),
        ("ernie", "ErnieForSequenceClassification"),
        ("ernie_m", "ErnieMForSequenceClassification"),
        ("esm", "EsmForSequenceClassification"),
        ("exaone4", "Exaone4ForSequenceClassification"),
        ("falcon", "FalconForSequenceClassification"),
        ("flaubert", "FlaubertForSequenceClassification"),
        ("fnet", "FNetForSequenceClassification"),
        ("funnel", "FunnelForSequenceClassification"),
        ("gemma", "GemmaForSequenceClassification"),
        ("gemma2", "Gemma2ForSequenceClassification"),
        ("gemma3", "Gemma3ForSequenceClassification"),
        ("gemma3_text", "Gemma3TextForSequenceClassification"),
        ("glm", "GlmForSequenceClassification"),
        ("glm4", "Glm4ForSequenceClassification"),
        ("gpt-sw3", "GPT2ForSequenceClassification"),
        ("gpt2", "GPT2ForSequenceClassification"),
        ("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
        ("gpt_neo", "GPTNeoForSequenceClassification"),
        ("gpt_neox", "GPTNeoXForSequenceClassification"),
        ("gpt_oss", "GptOssForSequenceClassification"),
        ("gptj", "GPTJForSequenceClassification"),
        ("helium", "HeliumForSequenceClassification"),
        ("hunyuan_v1_dense", "HunYuanDenseV1ForSequenceClassification"),
        ("hunyuan_v1_moe", "HunYuanMoEV1ForSequenceClassification"),
        ("ibert", "IBertForSequenceClassification"),
        ("jamba", "JambaForSequenceClassification"),
        ("jetmoe", "JetMoeForSequenceClassification"),
        ("layoutlm", "LayoutLMForSequenceClassification"),
        ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
        ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
        ("led", "LEDForSequenceClassification"),
        ("lilt", "LiltForSequenceClassification"),
        ("llama", "LlamaForSequenceClassification"),
        ("longformer", "LongformerForSequenceClassification"),
        ("luke", "LukeForSequenceClassification"),
        ("markuplm", "MarkupLMForSequenceClassification"),
        ("mbart", "MBartForSequenceClassification"),
        ("mega", "MegaForSequenceClassification"),
        ("megatron-bert", "MegatronBertForSequenceClassification"),
        ("minimax", "MiniMaxForSequenceClassification"),
        ("ministral", "MinistralForSequenceClassification"),
        ("mistral", "MistralForSequenceClassification"),
        ("mixtral", "MixtralForSequenceClassification"),
        ("mobilebert", "MobileBertForSequenceClassification"),
        ("modernbert", "ModernBertForSequenceClassification"),
        ("modernbert-decoder", "ModernBertDecoderForSequenceClassification"),
        ("mpnet", "MPNetForSequenceClassification"),
        ("mpt", "MptForSequenceClassification"),
        ("mra", "MraForSequenceClassification"),
        ("mt5", "MT5ForSequenceClassification"),
        ("mvp", "MvpForSequenceClassification"),
        ("nemotron", "NemotronForSequenceClassification"),
        ("nezha", "NezhaForSequenceClassification"),
        ("nystromformer", "NystromformerForSequenceClassification"),
        ("open-llama", "OpenLlamaForSequenceClassification"),
        ("openai-gpt", "OpenAIGPTForSequenceClassification"),
        ("opt", "OPTForSequenceClassification"),
        ("perceiver", "PerceiverForSequenceClassification"),
        ("persimmon", "PersimmonForSequenceClassification"),
        ("phi", "PhiForSequenceClassification"),
        ("phi3", "Phi3ForSequenceClassification"),
        ("phimoe", "PhimoeForSequenceClassification"),
        ("plbart", "PLBartForSequenceClassification"),
        ("qdqbert", "QDQBertForSequenceClassification"),
        ("qwen2", "Qwen2ForSequenceClassification"),
        ("qwen2_moe", "Qwen2MoeForSequenceClassification"),
        ("qwen3", "Qwen3ForSequenceClassification"),
        ("qwen3_moe", "Qwen3MoeForSequenceClassification"),
        ("qwen3_next", "Qwen3NextForSequenceClassification"),
        ("reformer", "ReformerForSequenceClassification"),
        ("rembert", "RemBertForSequenceClassification"),
        ("roberta", "RobertaForSequenceClassification"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"),
        ("roc_bert", "RoCBertForSequenceClassification"),
        ("roformer", "RoFormerForSequenceClassification"),
        ("seed_oss", "SeedOssForSequenceClassification"),
        ("smollm3", "SmolLM3ForSequenceClassification"),
        ("squeezebert", "SqueezeBertForSequenceClassification"),
        ("stablelm", "StableLmForSequenceClassification"),
        ("starcoder2", "Starcoder2ForSequenceClassification"),
        ("t5", "T5ForSequenceClassification"),
        ("t5gemma", "T5GemmaForSequenceClassification"),
        ("tapas", "TapasForSequenceClassification"),
        ("transfo-xl", "TransfoXLForSequenceClassification"),
        ("umt5", "UMT5ForSequenceClassification"),
        ("xlm", "XLMForSequenceClassification"),
        ("xlm-roberta", "XLMRobertaForSequenceClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
        ("xlnet", "XLNetForSequenceClassification"),
        ("xmod", "XmodForSequenceClassification"),
        ("yoso", "YosoForSequenceClassification"),
        ("zamba", "ZambaForSequenceClassification"),
        ("zamba2", "Zamba2ForSequenceClassification"),
    ]
)
  1294. MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
  1295. [
  1296. # Model for Question Answering mapping
  1297. ("albert", "AlbertForQuestionAnswering"),
  1298. ("arcee", "ArceeForQuestionAnswering"),
  1299. ("bart", "BartForQuestionAnswering"),
  1300. ("bert", "BertForQuestionAnswering"),
  1301. ("big_bird", "BigBirdForQuestionAnswering"),
  1302. ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
  1303. ("bloom", "BloomForQuestionAnswering"),
  1304. ("camembert", "CamembertForQuestionAnswering"),
  1305. ("canine", "CanineForQuestionAnswering"),
  1306. ("convbert", "ConvBertForQuestionAnswering"),
  1307. ("data2vec-text", "Data2VecTextForQuestionAnswering"),
  1308. ("deberta", "DebertaForQuestionAnswering"),
  1309. ("deberta-v2", "DebertaV2ForQuestionAnswering"),
  1310. ("diffllama", "DiffLlamaForQuestionAnswering"),
  1311. ("distilbert", "DistilBertForQuestionAnswering"),
  1312. ("electra", "ElectraForQuestionAnswering"),
  1313. ("ernie", "ErnieForQuestionAnswering"),
  1314. ("ernie_m", "ErnieMForQuestionAnswering"),
  1315. ("exaone4", "Exaone4ForQuestionAnswering"),
  1316. ("falcon", "FalconForQuestionAnswering"),
  1317. ("flaubert", "FlaubertForQuestionAnsweringSimple"),
  1318. ("fnet", "FNetForQuestionAnswering"),
  1319. ("funnel", "FunnelForQuestionAnswering"),
  1320. ("gpt2", "GPT2ForQuestionAnswering"),
  1321. ("gpt_neo", "GPTNeoForQuestionAnswering"),
  1322. ("gpt_neox", "GPTNeoXForQuestionAnswering"),
  1323. ("gptj", "GPTJForQuestionAnswering"),
  1324. ("ibert", "IBertForQuestionAnswering"),
  1325. ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
  1326. ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
  1327. ("led", "LEDForQuestionAnswering"),
  1328. ("lilt", "LiltForQuestionAnswering"),
  1329. ("llama", "LlamaForQuestionAnswering"),
  1330. ("longformer", "LongformerForQuestionAnswering"),
  1331. ("luke", "LukeForQuestionAnswering"),
  1332. ("lxmert", "LxmertForQuestionAnswering"),
  1333. ("markuplm", "MarkupLMForQuestionAnswering"),
  1334. ("mbart", "MBartForQuestionAnswering"),
  1335. ("mega", "MegaForQuestionAnswering"),
  1336. ("megatron-bert", "MegatronBertForQuestionAnswering"),
  1337. ("minimax", "MiniMaxForQuestionAnswering"),
  1338. ("ministral", "MinistralForQuestionAnswering"),
  1339. ("mistral", "MistralForQuestionAnswering"),
  1340. ("mixtral", "MixtralForQuestionAnswering"),
  1341. ("mobilebert", "MobileBertForQuestionAnswering"),
  1342. ("modernbert", "ModernBertForQuestionAnswering"),
  1343. ("mpnet", "MPNetForQuestionAnswering"),
  1344. ("mpt", "MptForQuestionAnswering"),
  1345. ("mra", "MraForQuestionAnswering"),
  1346. ("mt5", "MT5ForQuestionAnswering"),
  1347. ("mvp", "MvpForQuestionAnswering"),
  1348. ("nemotron", "NemotronForQuestionAnswering"),
  1349. ("nezha", "NezhaForQuestionAnswering"),
  1350. ("nystromformer", "NystromformerForQuestionAnswering"),
  1351. ("opt", "OPTForQuestionAnswering"),
  1352. ("qdqbert", "QDQBertForQuestionAnswering"),
  1353. ("qwen2", "Qwen2ForQuestionAnswering"),
  1354. ("qwen2_moe", "Qwen2MoeForQuestionAnswering"),
  1355. ("qwen3", "Qwen3ForQuestionAnswering"),
  1356. ("qwen3_moe", "Qwen3MoeForQuestionAnswering"),
  1357. ("qwen3_next", "Qwen3NextForQuestionAnswering"),
  1358. ("reformer", "ReformerForQuestionAnswering"),
  1359. ("rembert", "RemBertForQuestionAnswering"),
  1360. ("roberta", "RobertaForQuestionAnswering"),
  1361. ("roberta-prelayernorm", "RobertaPreLayerNormForQuestionAnswering"),
  1362. ("roc_bert", "RoCBertForQuestionAnswering"),
  1363. ("roformer", "RoFormerForQuestionAnswering"),
  1364. ("seed_oss", "SeedOssForQuestionAnswering"),
  1365. ("smollm3", "SmolLM3ForQuestionAnswering"),
  1366. ("splinter", "SplinterForQuestionAnswering"),
  1367. ("squeezebert", "SqueezeBertForQuestionAnswering"),
  1368. ("t5", "T5ForQuestionAnswering"),
  1369. ("umt5", "UMT5ForQuestionAnswering"),
  1370. ("xlm", "XLMForQuestionAnsweringSimple"),
  1371. ("xlm-roberta", "XLMRobertaForQuestionAnswering"),
  1372. ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
  1373. ("xlnet", "XLNetForQuestionAnsweringSimple"),
  1374. ("xmod", "XmodForQuestionAnswering"),
  1375. ("yoso", "YosoForQuestionAnswering"),
  1376. ]
  1377. )
  1378. MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
  1379. [
  1380. # Model for Table Question Answering mapping
  1381. ("tapas", "TapasForQuestionAnswering"),
  1382. ]
  1383. )
  1384. MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
  1385. [
  1386. ("blip", "BlipForQuestionAnswering"),
  1387. ("blip-2", "Blip2ForConditionalGeneration"),
  1388. ("vilt", "ViltForQuestionAnswering"),
  1389. ]
  1390. )
  1391. MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
  1392. [
  1393. ("layoutlm", "LayoutLMForQuestionAnswering"),
  1394. ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
  1395. ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
  1396. ]
  1397. )
  1398. MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  1399. [
  1400. # Model for Token Classification mapping
  1401. ("albert", "AlbertForTokenClassification"),
  1402. ("apertus", "ApertusForTokenClassification"),
  1403. ("arcee", "ArceeForTokenClassification"),
  1404. ("bert", "BertForTokenClassification"),
  1405. ("big_bird", "BigBirdForTokenClassification"),
  1406. ("biogpt", "BioGptForTokenClassification"),
  1407. ("bloom", "BloomForTokenClassification"),
  1408. ("bros", "BrosForTokenClassification"),
  1409. ("camembert", "CamembertForTokenClassification"),
  1410. ("canine", "CanineForTokenClassification"),
  1411. ("convbert", "ConvBertForTokenClassification"),
  1412. ("data2vec-text", "Data2VecTextForTokenClassification"),
  1413. ("deberta", "DebertaForTokenClassification"),
  1414. ("deberta-v2", "DebertaV2ForTokenClassification"),
  1415. ("deepseek_v3", "DeepseekV3ForTokenClassification"),
  1416. ("diffllama", "DiffLlamaForTokenClassification"),
  1417. ("distilbert", "DistilBertForTokenClassification"),
  1418. ("electra", "ElectraForTokenClassification"),
  1419. ("ernie", "ErnieForTokenClassification"),
  1420. ("ernie_m", "ErnieMForTokenClassification"),
  1421. ("esm", "EsmForTokenClassification"),
  1422. ("exaone4", "Exaone4ForTokenClassification"),
  1423. ("falcon", "FalconForTokenClassification"),
  1424. ("flaubert", "FlaubertForTokenClassification"),
  1425. ("fnet", "FNetForTokenClassification"),
  1426. ("funnel", "FunnelForTokenClassification"),
  1427. ("gemma", "GemmaForTokenClassification"),
  1428. ("gemma2", "Gemma2ForTokenClassification"),
  1429. ("glm", "GlmForTokenClassification"),
  1430. ("glm4", "Glm4ForTokenClassification"),
  1431. ("gpt-sw3", "GPT2ForTokenClassification"),
  1432. ("gpt2", "GPT2ForTokenClassification"),
  1433. ("gpt_bigcode", "GPTBigCodeForTokenClassification"),
  1434. ("gpt_neo", "GPTNeoForTokenClassification"),
  1435. ("gpt_neox", "GPTNeoXForTokenClassification"),
  1436. ("gpt_oss", "GptOssForTokenClassification"),
  1437. ("helium", "HeliumForTokenClassification"),
  1438. ("ibert", "IBertForTokenClassification"),
  1439. ("layoutlm", "LayoutLMForTokenClassification"),
  1440. ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
  1441. ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
  1442. ("lilt", "LiltForTokenClassification"),
  1443. ("llama", "LlamaForTokenClassification"),
  1444. ("longformer", "LongformerForTokenClassification"),
  1445. ("luke", "LukeForTokenClassification"),
  1446. ("markuplm", "MarkupLMForTokenClassification"),
  1447. ("mega", "MegaForTokenClassification"),
  1448. ("megatron-bert", "MegatronBertForTokenClassification"),
  1449. ("minimax", "MiniMaxForTokenClassification"),
  1450. ("ministral", "MinistralForTokenClassification"),
  1451. ("mistral", "MistralForTokenClassification"),
  1452. ("mixtral", "MixtralForTokenClassification"),
  1453. ("mobilebert", "MobileBertForTokenClassification"),
  1454. ("modernbert", "ModernBertForTokenClassification"),
  1455. ("mpnet", "MPNetForTokenClassification"),
  1456. ("mpt", "MptForTokenClassification"),
  1457. ("mra", "MraForTokenClassification"),
  1458. ("mt5", "MT5ForTokenClassification"),
  1459. ("nemotron", "NemotronForTokenClassification"),
  1460. ("nezha", "NezhaForTokenClassification"),
  1461. ("nystromformer", "NystromformerForTokenClassification"),
  1462. ("persimmon", "PersimmonForTokenClassification"),
  1463. ("phi", "PhiForTokenClassification"),
  1464. ("phi3", "Phi3ForTokenClassification"),
  1465. ("qdqbert", "QDQBertForTokenClassification"),
  1466. ("qwen2", "Qwen2ForTokenClassification"),
  1467. ("qwen2_moe", "Qwen2MoeForTokenClassification"),
  1468. ("qwen3", "Qwen3ForTokenClassification"),
  1469. ("qwen3_moe", "Qwen3MoeForTokenClassification"),
  1470. ("qwen3_next", "Qwen3NextForTokenClassification"),
  1471. ("rembert", "RemBertForTokenClassification"),
  1472. ("roberta", "RobertaForTokenClassification"),
  1473. ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
  1474. ("roc_bert", "RoCBertForTokenClassification"),
  1475. ("roformer", "RoFormerForTokenClassification"),
  1476. ("seed_oss", "SeedOssForTokenClassification"),
  1477. ("smollm3", "SmolLM3ForTokenClassification"),
  1478. ("squeezebert", "SqueezeBertForTokenClassification"),
  1479. ("stablelm", "StableLmForTokenClassification"),
  1480. ("starcoder2", "Starcoder2ForTokenClassification"),
  1481. ("t5", "T5ForTokenClassification"),
  1482. ("t5gemma", "T5GemmaForTokenClassification"),
  1483. ("umt5", "UMT5ForTokenClassification"),
  1484. ("xlm", "XLMForTokenClassification"),
  1485. ("xlm-roberta", "XLMRobertaForTokenClassification"),
  1486. ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
  1487. ("xlnet", "XLNetForTokenClassification"),
  1488. ("xmod", "XmodForTokenClassification"),
  1489. ("yoso", "YosoForTokenClassification"),
  1490. ]
  1491. )
  1492. MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
  1493. [
  1494. # Model for Multiple Choice mapping
  1495. ("albert", "AlbertForMultipleChoice"),
  1496. ("bert", "BertForMultipleChoice"),
  1497. ("big_bird", "BigBirdForMultipleChoice"),
  1498. ("camembert", "CamembertForMultipleChoice"),
  1499. ("canine", "CanineForMultipleChoice"),
  1500. ("convbert", "ConvBertForMultipleChoice"),
  1501. ("data2vec-text", "Data2VecTextForMultipleChoice"),
  1502. ("deberta-v2", "DebertaV2ForMultipleChoice"),
  1503. ("distilbert", "DistilBertForMultipleChoice"),
  1504. ("electra", "ElectraForMultipleChoice"),
  1505. ("ernie", "ErnieForMultipleChoice"),
  1506. ("ernie_m", "ErnieMForMultipleChoice"),
  1507. ("flaubert", "FlaubertForMultipleChoice"),
  1508. ("fnet", "FNetForMultipleChoice"),
  1509. ("funnel", "FunnelForMultipleChoice"),
  1510. ("ibert", "IBertForMultipleChoice"),
  1511. ("longformer", "LongformerForMultipleChoice"),
  1512. ("luke", "LukeForMultipleChoice"),
  1513. ("mega", "MegaForMultipleChoice"),
  1514. ("megatron-bert", "MegatronBertForMultipleChoice"),
  1515. ("mobilebert", "MobileBertForMultipleChoice"),
  1516. ("modernbert", "ModernBertForMultipleChoice"),
  1517. ("mpnet", "MPNetForMultipleChoice"),
  1518. ("mra", "MraForMultipleChoice"),
  1519. ("nezha", "NezhaForMultipleChoice"),
  1520. ("nystromformer", "NystromformerForMultipleChoice"),
  1521. ("qdqbert", "QDQBertForMultipleChoice"),
  1522. ("rembert", "RemBertForMultipleChoice"),
  1523. ("roberta", "RobertaForMultipleChoice"),
  1524. ("roberta-prelayernorm", "RobertaPreLayerNormForMultipleChoice"),
  1525. ("roc_bert", "RoCBertForMultipleChoice"),
  1526. ("roformer", "RoFormerForMultipleChoice"),
  1527. ("squeezebert", "SqueezeBertForMultipleChoice"),
  1528. ("xlm", "XLMForMultipleChoice"),
  1529. ("xlm-roberta", "XLMRobertaForMultipleChoice"),
  1530. ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
  1531. ("xlnet", "XLNetForMultipleChoice"),
  1532. ("xmod", "XmodForMultipleChoice"),
  1533. ("yoso", "YosoForMultipleChoice"),
  1534. ]
  1535. )
  1536. MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
  1537. [
  1538. ("bert", "BertForNextSentencePrediction"),
  1539. ("ernie", "ErnieForNextSentencePrediction"),
  1540. ("fnet", "FNetForNextSentencePrediction"),
  1541. ("megatron-bert", "MegatronBertForNextSentencePrediction"),
  1542. ("mobilebert", "MobileBertForNextSentencePrediction"),
  1543. ("nezha", "NezhaForNextSentencePrediction"),
  1544. ("qdqbert", "QDQBertForNextSentencePrediction"),
  1545. ]
  1546. )
  1547. MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  1548. [
  1549. # Model for Audio Classification mapping
  1550. ("audio-spectrogram-transformer", "ASTForAudioClassification"),
  1551. ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
  1552. ("hubert", "HubertForSequenceClassification"),
  1553. ("sew", "SEWForSequenceClassification"),
  1554. ("sew-d", "SEWDForSequenceClassification"),
  1555. ("unispeech", "UniSpeechForSequenceClassification"),
  1556. ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
  1557. ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
  1558. ("wav2vec2-bert", "Wav2Vec2BertForSequenceClassification"),
  1559. ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
  1560. ("wavlm", "WavLMForSequenceClassification"),
  1561. ("whisper", "WhisperForAudioClassification"),
  1562. ]
  1563. )
  1564. MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
  1565. [
  1566. # Model for Connectionist temporal classification (CTC) mapping
  1567. ("data2vec-audio", "Data2VecAudioForCTC"),
  1568. ("hubert", "HubertForCTC"),
  1569. ("mctct", "MCTCTForCTC"),
  1570. ("parakeet_ctc", "ParakeetForCTC"),
  1571. ("sew", "SEWForCTC"),
  1572. ("sew-d", "SEWDForCTC"),
  1573. ("unispeech", "UniSpeechForCTC"),
  1574. ("unispeech-sat", "UniSpeechSatForCTC"),
  1575. ("wav2vec2", "Wav2Vec2ForCTC"),
  1576. ("wav2vec2-bert", "Wav2Vec2BertForCTC"),
  1577. ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
  1578. ("wavlm", "WavLMForCTC"),
  1579. ]
  1580. )
  1581. MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  1582. [
  1583. # Model for Audio Classification mapping
  1584. ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
  1585. ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
  1586. ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
  1587. ("wav2vec2-bert", "Wav2Vec2BertForAudioFrameClassification"),
  1588. ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
  1589. ("wavlm", "WavLMForAudioFrameClassification"),
  1590. ]
  1591. )
  1592. MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
  1593. [
  1594. # Model for Audio Classification mapping
  1595. ("data2vec-audio", "Data2VecAudioForXVector"),
  1596. ("unispeech-sat", "UniSpeechSatForXVector"),
  1597. ("wav2vec2", "Wav2Vec2ForXVector"),
  1598. ("wav2vec2-bert", "Wav2Vec2BertForXVector"),
  1599. ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
  1600. ("wavlm", "WavLMForXVector"),
  1601. ]
  1602. )
  1603. MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict(
  1604. [
  1605. # Model for Text-To-Spectrogram mapping
  1606. ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
  1607. ("speecht5", "SpeechT5ForTextToSpeech"),
  1608. ]
  1609. )
  1610. MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
  1611. [
  1612. # Model for Text-To-Waveform mapping
  1613. ("bark", "BarkModel"),
  1614. ("csm", "CsmForConditionalGeneration"),
  1615. ("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
  1616. ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
  1617. ("musicgen", "MusicgenForConditionalGeneration"),
  1618. ("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),
  1619. ("qwen2_5_omni", "Qwen2_5OmniForConditionalGeneration"),
  1620. ("qwen3_omni_moe", "Qwen3OmniMoeForConditionalGeneration"),
  1621. ("seamless_m4t", "SeamlessM4TForTextToSpeech"),
  1622. ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToSpeech"),
  1623. ("vits", "VitsModel"),
  1624. ]
  1625. )
  1626. MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  1627. [
  1628. # Model for Zero Shot Image Classification mapping
  1629. ("align", "AlignModel"),
  1630. ("altclip", "AltCLIPModel"),
  1631. ("blip", "BlipModel"),
  1632. ("blip-2", "Blip2ForImageTextRetrieval"),
  1633. ("chinese_clip", "ChineseCLIPModel"),
  1634. ("clip", "CLIPModel"),
  1635. ("clipseg", "CLIPSegModel"),
  1636. ("metaclip_2", "MetaClip2Model"),
  1637. ("siglip", "SiglipModel"),
  1638. ("siglip2", "Siglip2Model"),
  1639. ]
  1640. )
  1641. MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
  1642. [
  1643. # Backbone mapping
  1644. ("beit", "BeitBackbone"),
  1645. ("bit", "BitBackbone"),
  1646. ("convnext", "ConvNextBackbone"),
  1647. ("convnextv2", "ConvNextV2Backbone"),
  1648. ("dinat", "DinatBackbone"),
  1649. ("dinov2", "Dinov2Backbone"),
  1650. ("dinov2_with_registers", "Dinov2WithRegistersBackbone"),
  1651. ("focalnet", "FocalNetBackbone"),
  1652. ("hgnet_v2", "HGNetV2Backbone"),
  1653. ("hiera", "HieraBackbone"),
  1654. ("maskformer-swin", "MaskFormerSwinBackbone"),
  1655. ("nat", "NatBackbone"),
  1656. ("pvt_v2", "PvtV2Backbone"),
  1657. ("resnet", "ResNetBackbone"),
  1658. ("rt_detr_resnet", "RTDetrResNetBackbone"),
  1659. ("swin", "SwinBackbone"),
  1660. ("swinv2", "Swinv2Backbone"),
  1661. ("textnet", "TextNetBackbone"),
  1662. ("timm_backbone", "TimmBackbone"),
  1663. ("vitdet", "VitDetBackbone"),
  1664. ("vitpose_backbone", "VitPoseBackbone"),
  1665. ]
  1666. )
  1667. MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
  1668. [
  1669. ("edgetam", "EdgeTamModel"),
  1670. ("edgetam_video", "EdgeTamModel"),
  1671. ("sam", "SamModel"),
  1672. ("sam2", "Sam2Model"),
  1673. ("sam2_video", "Sam2Model"),
  1674. ("sam_hq", "SamHQModel"),
  1675. ]
  1676. )
  1677. MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
  1678. [
  1679. ("superpoint", "SuperPointForKeypointDetection"),
  1680. ]
  1681. )
  1682. MODEL_FOR_KEYPOINT_MATCHING_MAPPING_NAMES = OrderedDict(
  1683. [
  1684. ("efficientloftr", "EfficientLoFTRForKeypointMatching"),
  1685. ("lightglue", "LightGlueForKeypointMatching"),
  1686. ("superglue", "SuperGlueForKeypointMatching"),
  1687. ]
  1688. )
  1689. MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
  1690. [
  1691. ("albert", "AlbertModel"),
  1692. ("bert", "BertModel"),
  1693. ("big_bird", "BigBirdModel"),
  1694. ("clip_text_model", "CLIPTextModel"),
  1695. ("data2vec-text", "Data2VecTextModel"),
  1696. ("deberta", "DebertaModel"),
  1697. ("deberta-v2", "DebertaV2Model"),
  1698. ("distilbert", "DistilBertModel"),
  1699. ("electra", "ElectraModel"),
  1700. ("emu3", "Emu3TextModel"),
  1701. ("flaubert", "FlaubertModel"),
  1702. ("ibert", "IBertModel"),
  1703. ("llama4", "Llama4TextModel"),
  1704. ("longformer", "LongformerModel"),
  1705. ("mllama", "MllamaTextModel"),
  1706. ("mobilebert", "MobileBertModel"),
  1707. ("mt5", "MT5EncoderModel"),
  1708. ("nystromformer", "NystromformerModel"),
  1709. ("reformer", "ReformerModel"),
  1710. ("rembert", "RemBertModel"),
  1711. ("roberta", "RobertaModel"),
  1712. ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
  1713. ("roc_bert", "RoCBertModel"),
  1714. ("roformer", "RoFormerModel"),
  1715. ("squeezebert", "SqueezeBertModel"),
  1716. ("t5", "T5EncoderModel"),
  1717. ("t5gemma", "T5GemmaEncoderModel"),
  1718. ("umt5", "UMT5EncoderModel"),
  1719. ("xlm", "XLMModel"),
  1720. ("xlm-roberta", "XLMRobertaModel"),
  1721. ("xlm-roberta-xl", "XLMRobertaXLModel"),
  1722. ]
  1723. )
  1724. MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  1725. [
  1726. ("patchtsmixer", "PatchTSMixerForTimeSeriesClassification"),
  1727. ("patchtst", "PatchTSTForClassification"),
  1728. ]
  1729. )
  1730. MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict(
  1731. [
  1732. ("patchtsmixer", "PatchTSMixerForRegression"),
  1733. ("patchtst", "PatchTSTForRegression"),
  1734. ]
  1735. )
  1736. MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = OrderedDict(
  1737. [
  1738. ("timesfm", "TimesFmModelForPrediction"),
  1739. ]
  1740. )
  1741. MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
  1742. [
  1743. ("swin2sr", "Swin2SRForImageSuperResolution"),
  1744. ]
  1745. )
  1746. MODEL_FOR_AUDIO_TOKENIZATION_NAMES = OrderedDict(
  1747. [
  1748. ("dac", "DacModel"),
  1749. ]
  1750. )
  1751. MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
  1752. MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
  1753. MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
  1754. MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
  1755. MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
  1756. CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
  1757. )
  1758. MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1759. CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
  1760. )
  1761. MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1762. CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
  1763. )
  1764. MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
  1765. CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
  1766. )
  1767. MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
  1768. CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
  1769. )
  1770. MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
  1771. CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
  1772. )
  1773. MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = _LazyAutoMapping(
  1774. CONFIG_MAPPING_NAMES, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES
  1775. )
  1776. MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1777. CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
  1778. )
  1779. MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
  1780. MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping(
  1781. CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
  1782. )
  1783. MODEL_FOR_RETRIEVAL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_RETRIEVAL_MAPPING_NAMES)
  1784. MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
  1785. CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
  1786. )
  1787. MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
  1788. CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
  1789. )
  1790. MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
  1791. MODEL_FOR_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_MAPPING_NAMES)
  1792. MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
  1793. CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
  1794. )
  1795. MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
  1796. MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
  1797. CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
  1798. )
  1799. MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
  1800. MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
  1801. CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
  1802. )
  1803. MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1804. CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
  1805. )
  1806. MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
  1807. CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
  1808. )
  1809. MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
  1810. CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
  1811. )
  1812. MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1813. CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
  1814. )
  1815. MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
  1816. MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
  1817. CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
  1818. )
  1819. MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1820. CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
  1821. )
  1822. MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
  1823. MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
  1824. MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1825. CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
  1826. )
  1827. MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
  1828. MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping(
  1829. CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES
  1830. )
  1831. MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES)
  1832. MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)
  1833. MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
  1834. MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping(
  1835. CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES
  1836. )
  1837. MODEL_FOR_KEYPOINT_MATCHING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_MATCHING_MAPPING_NAMES)
  1838. MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
  1839. MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1840. CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES
  1841. )
  1842. MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping(
  1843. CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES
  1844. )
  1845. MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING = _LazyAutoMapping(
  1846. CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES
  1847. )
  1848. MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
  1849. MODEL_FOR_AUDIO_TOKENIZATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_TOKENIZATION_NAMES)
  1850. class AutoModelForMaskGeneration(_BaseAutoModelClass):
  1851. _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
  1852. class AutoModelForKeypointDetection(_BaseAutoModelClass):
  1853. _model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING
  1854. class AutoModelForKeypointMatching(_BaseAutoModelClass):
  1855. _model_mapping = MODEL_FOR_KEYPOINT_MATCHING_MAPPING
  1856. class AutoModelForTextEncoding(_BaseAutoModelClass):
  1857. _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
  1858. class AutoModelForImageToImage(_BaseAutoModelClass):
  1859. _model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING
  1860. class AutoModel(_BaseAutoModelClass):
  1861. _model_mapping = MODEL_MAPPING
  1862. AutoModel = auto_class_update(AutoModel)
  1863. class AutoModelForPreTraining(_BaseAutoModelClass):
  1864. _model_mapping = MODEL_FOR_PRETRAINING_MAPPING
  1865. AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")
  1866. # Private on purpose, the public class will add the deprecation warnings.
  1867. class _AutoModelWithLMHead(_BaseAutoModelClass):
  1868. _model_mapping = MODEL_WITH_LM_HEAD_MAPPING
  1869. _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")
  1870. class AutoModelForCausalLM(_BaseAutoModelClass):
  1871. _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
  1872. # override to give better return typehint
  1873. @classmethod
  1874. def from_pretrained(
  1875. cls: type["AutoModelForCausalLM"],
  1876. pretrained_model_name_or_path: Union[str, os.PathLike[str]],
  1877. *model_args,
  1878. **kwargs,
  1879. ) -> "_BaseModelWithGenerate":
  1880. return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
  1881. AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
  1882. class AutoModelForMaskedLM(_BaseAutoModelClass):
  1883. _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
  1884. AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")
  1885. class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
  1886. _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
  1887. AutoModelForSeq2SeqLM = auto_class_update(
  1888. AutoModelForSeq2SeqLM,
  1889. head_doc="sequence-to-sequence language modeling",
  1890. checkpoint_for_example="google-t5/t5-base",
  1891. )
  1892. class AutoModelForSequenceClassification(_BaseAutoModelClass):
  1893. _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
  1894. AutoModelForSequenceClassification = auto_class_update(
  1895. AutoModelForSequenceClassification, head_doc="sequence classification"
  1896. )
  1897. class AutoModelForQuestionAnswering(_BaseAutoModelClass):
  1898. _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
  1899. AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")
  1900. class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
  1901. _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
  1902. AutoModelForTableQuestionAnswering = auto_class_update(
  1903. AutoModelForTableQuestionAnswering,
  1904. head_doc="table question answering",
  1905. checkpoint_for_example="google/tapas-base-finetuned-wtq",
  1906. )
  1907. class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
  1908. _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
  1909. AutoModelForVisualQuestionAnswering = auto_class_update(
  1910. AutoModelForVisualQuestionAnswering,
  1911. head_doc="visual question answering",
  1912. checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
  1913. )
  1914. class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
  1915. _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING
  1916. AutoModelForDocumentQuestionAnswering = auto_class_update(
  1917. AutoModelForDocumentQuestionAnswering,
  1918. head_doc="document question answering",
  1919. checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
  1920. )
  1921. class AutoModelForTokenClassification(_BaseAutoModelClass):
  1922. _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
  1923. AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")
  1924. class AutoModelForMultipleChoice(_BaseAutoModelClass):
  1925. _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING
  1926. AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")
  1927. class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
  1928. _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
  1929. AutoModelForNextSentencePrediction = auto_class_update(
  1930. AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
  1931. )
  1932. class AutoModelForImageClassification(_BaseAutoModelClass):
  1933. _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
  1934. AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
  1935. class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
  1936. _model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
  1937. AutoModelForZeroShotImageClassification = auto_class_update(
  1938. AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
  1939. )
  1940. class AutoModelForImageSegmentation(_BaseAutoModelClass):
  1941. _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
  1942. AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")
  1943. class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
  1944. _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
  1945. AutoModelForSemanticSegmentation = auto_class_update(
  1946. AutoModelForSemanticSegmentation, head_doc="semantic segmentation"
  1947. )
  1948. class AutoModelForTimeSeriesPrediction(_BaseAutoModelClass):
  1949. _model_mapping = MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING
  1950. AutoModelForTimeSeriesPrediction = auto_class_update(
  1951. AutoModelForTimeSeriesPrediction, head_doc="time-series prediction"
  1952. )
  1953. class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
  1954. _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING
  1955. AutoModelForUniversalSegmentation = auto_class_update(
  1956. AutoModelForUniversalSegmentation, head_doc="universal image segmentation"
  1957. )
  1958. class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
  1959. _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING
  1960. AutoModelForInstanceSegmentation = auto_class_update(
  1961. AutoModelForInstanceSegmentation, head_doc="instance segmentation"
  1962. )
  1963. class AutoModelForObjectDetection(_BaseAutoModelClass):
  1964. _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
  1965. AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
  1966. class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
  1967. _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
  1968. AutoModelForZeroShotObjectDetection = auto_class_update(
  1969. AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection"
  1970. )
  1971. class AutoModelForDepthEstimation(_BaseAutoModelClass):
  1972. _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
  1973. AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation")
  1974. class AutoModelForVideoClassification(_BaseAutoModelClass):
  1975. _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
  1976. AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification")
  1977. # Private on purpose, the public class will add the deprecation warnings.
  1978. class _AutoModelForVision2Seq(_BaseAutoModelClass):
  1979. _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
  1980. _AutoModelForVision2Seq = auto_class_update(_AutoModelForVision2Seq, head_doc="vision-to-text modeling")
  1981. class AutoModelForImageTextToText(_BaseAutoModelClass):
  1982. _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
  1983. # override to give better return typehint
  1984. @classmethod
  1985. def from_pretrained(
  1986. cls: type["AutoModelForImageTextToText"],
  1987. pretrained_model_name_or_path: Union[str, os.PathLike[str]],
  1988. *model_args,
  1989. **kwargs,
  1990. ) -> "_BaseModelWithGenerate":
  1991. return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
  1992. AutoModelForImageTextToText = auto_class_update(AutoModelForImageTextToText, head_doc="image-text-to-text modeling")
  1993. class AutoModelForAudioClassification(_BaseAutoModelClass):
  1994. _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
  1995. AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")
  1996. class AutoModelForCTC(_BaseAutoModelClass):
  1997. _model_mapping = MODEL_FOR_CTC_MAPPING
  1998. AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")
  1999. class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
  2000. _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
  2001. AutoModelForSpeechSeq2Seq = auto_class_update(
  2002. AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
  2003. )
  2004. class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
  2005. _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING
  2006. AutoModelForAudioFrameClassification = auto_class_update(
  2007. AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification"
  2008. )
  2009. class AutoModelForAudioXVector(_BaseAutoModelClass):
  2010. _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
  2011. class AutoModelForTextToSpectrogram(_BaseAutoModelClass):
  2012. _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
  2013. class AutoModelForTextToWaveform(_BaseAutoModelClass):
  2014. _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
  2015. class AutoBackbone(_BaseAutoBackboneClass):
  2016. _model_mapping = MODEL_FOR_BACKBONE_MAPPING
  2017. AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")
  2018. class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
  2019. _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING
  2020. AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")
  2021. class AutoModelForAudioTokenization(_BaseAutoModelClass):
  2022. _model_mapping = MODEL_FOR_AUDIO_TOKENIZATION_MAPPING
  2023. AutoModelForAudioTokenization = auto_class_update(
  2024. AutoModelForAudioTokenization, head_doc="audio tokenization through codebooks"
  2025. )
  2026. class AutoModelWithLMHead(_AutoModelWithLMHead):
  2027. @classmethod
  2028. def from_config(cls, config, **kwargs):
  2029. warnings.warn(
  2030. "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
  2031. "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
  2032. "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
  2033. FutureWarning,
  2034. )
  2035. return super().from_config(config, **kwargs)
  2036. @classmethod
  2037. def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
  2038. warnings.warn(
  2039. "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
  2040. "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
  2041. "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
  2042. FutureWarning,
  2043. )
  2044. return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
  2045. class AutoModelForVision2Seq(_AutoModelForVision2Seq):
  2046. @classmethod
  2047. def from_config(cls, config, **kwargs):
  2048. warnings.warn(
  2049. "The class `AutoModelForVision2Seq` is deprecated and will be removed in v5.0. Please use "
  2050. "`AutoModelForImageTextToText` instead.",
  2051. FutureWarning,
  2052. )
  2053. return super().from_config(config, **kwargs)
  2054. @classmethod
  2055. def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
  2056. warnings.warn(
  2057. "The class `AutoModelForVision2Seq` is deprecated and will be removed in v5.0. Please use "
  2058. "`AutoModelForImageTextToText` instead.",
  2059. FutureWarning,
  2060. )
  2061. return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
# Public names exported by this module: first the mapping objects, then the auto classes.
# NOTE(review): kept as one flat literal in its existing order; presumably repo tooling reads
# this list statically — confirm before reordering or building it programmatically.
__all__ = [
    # --- Mapping objects ---
    "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
    "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING",
    "MODEL_FOR_AUDIO_TOKENIZATION_MAPPING",
    "MODEL_FOR_AUDIO_XVECTOR_MAPPING",
    "MODEL_FOR_BACKBONE_MAPPING",
    "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
    "MODEL_FOR_CAUSAL_LM_MAPPING",
    "MODEL_FOR_CTC_MAPPING",
    "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
    "MODEL_FOR_DEPTH_ESTIMATION_MAPPING",
    "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
    "MODEL_FOR_IMAGE_MAPPING",
    "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
    "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
    "MODEL_FOR_KEYPOINT_DETECTION_MAPPING",
    "MODEL_FOR_KEYPOINT_MATCHING_MAPPING",
    "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
    "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
    "MODEL_FOR_MASKED_LM_MAPPING",
    "MODEL_FOR_MASK_GENERATION_MAPPING",
    "MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
    "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
    "MODEL_FOR_OBJECT_DETECTION_MAPPING",
    "MODEL_FOR_PRETRAINING_MAPPING",
    "MODEL_FOR_QUESTION_ANSWERING_MAPPING",
    "MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
    "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
    "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
    "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
    "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
    "MODEL_FOR_TEXT_ENCODING_MAPPING",
    "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING",
    "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING",
    "MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING",
    "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
    "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
    "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
    "MODEL_FOR_VISION_2_SEQ_MAPPING",
    "MODEL_FOR_RETRIEVAL_MAPPING",
    "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING",
    "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
    "MODEL_MAPPING",
    "MODEL_WITH_LM_HEAD_MAPPING",
    "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
    "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
    "MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING",
    "MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING",
    # --- Auto classes ---
    "AutoModel",
    "AutoBackbone",
    "AutoModelForAudioClassification",
    "AutoModelForAudioFrameClassification",
    "AutoModelForAudioTokenization",
    "AutoModelForAudioXVector",
    "AutoModelForCausalLM",
    "AutoModelForCTC",
    "AutoModelForDepthEstimation",
    "AutoModelForImageClassification",
    "AutoModelForImageSegmentation",
    "AutoModelForImageToImage",
    "AutoModelForInstanceSegmentation",
    "AutoModelForKeypointDetection",
    "AutoModelForKeypointMatching",
    "AutoModelForMaskGeneration",
    "AutoModelForTextEncoding",
    "AutoModelForMaskedImageModeling",
    "AutoModelForMaskedLM",
    "AutoModelForMultipleChoice",
    "AutoModelForNextSentencePrediction",
    "AutoModelForObjectDetection",
    "AutoModelForPreTraining",
    "AutoModelForQuestionAnswering",
    "AutoModelForSemanticSegmentation",
    "AutoModelForSeq2SeqLM",
    "AutoModelForSequenceClassification",
    "AutoModelForSpeechSeq2Seq",
    "AutoModelForTableQuestionAnswering",
    "AutoModelForTextToSpectrogram",
    "AutoModelForTextToWaveform",
    "AutoModelForTimeSeriesPrediction",
    "AutoModelForTokenClassification",
    "AutoModelForUniversalSegmentation",
    "AutoModelForVideoClassification",
    "AutoModelForVision2Seq",
    "AutoModelForVisualQuestionAnswering",
    "AutoModelForDocumentQuestionAnswering",
    "AutoModelWithLMHead",
    "AutoModelForZeroShotImageClassification",
    "AutoModelForZeroShotObjectDetection",
    "AutoModelForImageTextToText",
]