features.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749
  1. import os
  2. from functools import partial, reduce
  3. from typing import TYPE_CHECKING, Callable, Optional, Union
  4. import transformers
  5. from .. import PretrainedConfig, is_tf_available, is_torch_available
  6. from ..utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging
  7. from .config import OnnxConfig
  8. if TYPE_CHECKING:
  9. from transformers import PreTrainedModel, TFPreTrainedModel
  10. logger = logging.get_logger(__name__) # pylint: disable=invalid-name
  11. if is_torch_available():
  12. from transformers.models.auto import (
  13. AutoModel,
  14. AutoModelForCausalLM,
  15. AutoModelForImageClassification,
  16. AutoModelForImageSegmentation,
  17. AutoModelForMaskedImageModeling,
  18. AutoModelForMaskedLM,
  19. AutoModelForMultipleChoice,
  20. AutoModelForObjectDetection,
  21. AutoModelForQuestionAnswering,
  22. AutoModelForSemanticSegmentation,
  23. AutoModelForSeq2SeqLM,
  24. AutoModelForSequenceClassification,
  25. AutoModelForSpeechSeq2Seq,
  26. AutoModelForTokenClassification,
  27. AutoModelForVision2Seq,
  28. )
  29. if is_tf_available():
  30. from transformers.models.auto import (
  31. TFAutoModel,
  32. TFAutoModelForCausalLM,
  33. TFAutoModelForMaskedLM,
  34. TFAutoModelForMultipleChoice,
  35. TFAutoModelForQuestionAnswering,
  36. TFAutoModelForSemanticSegmentation,
  37. TFAutoModelForSeq2SeqLM,
  38. TFAutoModelForSequenceClassification,
  39. TFAutoModelForTokenClassification,
  40. )
  41. if not is_torch_available() and not is_tf_available():
  42. logger.warning(
  43. "The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models"
  44. " without one of these libraries installed."
  45. )
  46. def supported_features_mapping(
  47. *supported_features: str, onnx_config_cls: Optional[str] = None
  48. ) -> dict[str, Callable[[PretrainedConfig], OnnxConfig]]:
  49. """
  50. Generate the mapping between supported the features and their corresponding OnnxConfig for a given model.
  51. Args:
  52. *supported_features: The names of the supported features.
  53. onnx_config_cls: The OnnxConfig full name corresponding to the model.
  54. Returns:
  55. The dictionary mapping a feature to an OnnxConfig constructor.
  56. """
  57. if onnx_config_cls is None:
  58. raise ValueError("A OnnxConfig class must be provided")
  59. config_cls = transformers
  60. for attr_name in onnx_config_cls.split("."):
  61. config_cls = getattr(config_cls, attr_name)
  62. mapping = {}
  63. for feature in supported_features:
  64. if "-with-past" in feature:
  65. task = feature.replace("-with-past", "")
  66. mapping[feature] = partial(config_cls.with_past, task=task)
  67. else:
  68. mapping[feature] = partial(config_cls.from_model_config, task=feature)
  69. return mapping
  70. class FeaturesManager:
  71. _TASKS_TO_AUTOMODELS = {}
  72. _TASKS_TO_TF_AUTOMODELS = {}
  73. if is_torch_available():
  74. _TASKS_TO_AUTOMODELS = {
  75. "default": AutoModel,
  76. "masked-lm": AutoModelForMaskedLM,
  77. "causal-lm": AutoModelForCausalLM,
  78. "seq2seq-lm": AutoModelForSeq2SeqLM,
  79. "sequence-classification": AutoModelForSequenceClassification,
  80. "token-classification": AutoModelForTokenClassification,
  81. "multiple-choice": AutoModelForMultipleChoice,
  82. "object-detection": AutoModelForObjectDetection,
  83. "question-answering": AutoModelForQuestionAnswering,
  84. "image-classification": AutoModelForImageClassification,
  85. "image-segmentation": AutoModelForImageSegmentation,
  86. "masked-im": AutoModelForMaskedImageModeling,
  87. "semantic-segmentation": AutoModelForSemanticSegmentation,
  88. "vision2seq-lm": AutoModelForVision2Seq,
  89. "speech2seq-lm": AutoModelForSpeechSeq2Seq,
  90. }
  91. if is_tf_available():
  92. _TASKS_TO_TF_AUTOMODELS = {
  93. "default": TFAutoModel,
  94. "masked-lm": TFAutoModelForMaskedLM,
  95. "causal-lm": TFAutoModelForCausalLM,
  96. "seq2seq-lm": TFAutoModelForSeq2SeqLM,
  97. "sequence-classification": TFAutoModelForSequenceClassification,
  98. "token-classification": TFAutoModelForTokenClassification,
  99. "multiple-choice": TFAutoModelForMultipleChoice,
  100. "question-answering": TFAutoModelForQuestionAnswering,
  101. "semantic-segmentation": TFAutoModelForSemanticSegmentation,
  102. }
  103. # Set of model topologies we support associated to the features supported by each topology and the factory
  104. _SUPPORTED_MODEL_TYPE = {
  105. "albert": supported_features_mapping(
  106. "default",
  107. "masked-lm",
  108. "sequence-classification",
  109. "multiple-choice",
  110. "token-classification",
  111. "question-answering",
  112. onnx_config_cls="models.albert.AlbertOnnxConfig",
  113. ),
  114. "bart": supported_features_mapping(
  115. "default",
  116. "default-with-past",
  117. "causal-lm",
  118. "causal-lm-with-past",
  119. "seq2seq-lm",
  120. "seq2seq-lm-with-past",
  121. "sequence-classification",
  122. "question-answering",
  123. onnx_config_cls="models.bart.BartOnnxConfig",
  124. ),
  125. # BEiT cannot be used with the masked image modeling autoclass, so this feature is excluded here
  126. "beit": supported_features_mapping(
  127. "default", "image-classification", onnx_config_cls="models.beit.BeitOnnxConfig"
  128. ),
  129. "bert": supported_features_mapping(
  130. "default",
  131. "masked-lm",
  132. "causal-lm",
  133. "sequence-classification",
  134. "multiple-choice",
  135. "token-classification",
  136. "question-answering",
  137. onnx_config_cls="models.bert.BertOnnxConfig",
  138. ),
  139. "big-bird": supported_features_mapping(
  140. "default",
  141. "masked-lm",
  142. "causal-lm",
  143. "sequence-classification",
  144. "multiple-choice",
  145. "token-classification",
  146. "question-answering",
  147. onnx_config_cls="models.big_bird.BigBirdOnnxConfig",
  148. ),
  149. "bigbird-pegasus": supported_features_mapping(
  150. "default",
  151. "default-with-past",
  152. "causal-lm",
  153. "causal-lm-with-past",
  154. "seq2seq-lm",
  155. "seq2seq-lm-with-past",
  156. "sequence-classification",
  157. "question-answering",
  158. onnx_config_cls="models.bigbird_pegasus.BigBirdPegasusOnnxConfig",
  159. ),
  160. "blenderbot": supported_features_mapping(
  161. "default",
  162. "default-with-past",
  163. "causal-lm",
  164. "causal-lm-with-past",
  165. "seq2seq-lm",
  166. "seq2seq-lm-with-past",
  167. onnx_config_cls="models.blenderbot.BlenderbotOnnxConfig",
  168. ),
  169. "blenderbot-small": supported_features_mapping(
  170. "default",
  171. "default-with-past",
  172. "causal-lm",
  173. "causal-lm-with-past",
  174. "seq2seq-lm",
  175. "seq2seq-lm-with-past",
  176. onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig",
  177. ),
  178. "bloom": supported_features_mapping(
  179. "default",
  180. "default-with-past",
  181. "causal-lm",
  182. "causal-lm-with-past",
  183. "sequence-classification",
  184. "token-classification",
  185. onnx_config_cls="models.bloom.BloomOnnxConfig",
  186. ),
  187. "camembert": supported_features_mapping(
  188. "default",
  189. "masked-lm",
  190. "causal-lm",
  191. "sequence-classification",
  192. "multiple-choice",
  193. "token-classification",
  194. "question-answering",
  195. onnx_config_cls="models.camembert.CamembertOnnxConfig",
  196. ),
  197. "clip": supported_features_mapping(
  198. "default",
  199. onnx_config_cls="models.clip.CLIPOnnxConfig",
  200. ),
  201. "codegen": supported_features_mapping(
  202. "default",
  203. "causal-lm",
  204. onnx_config_cls="models.codegen.CodeGenOnnxConfig",
  205. ),
  206. "convbert": supported_features_mapping(
  207. "default",
  208. "masked-lm",
  209. "sequence-classification",
  210. "multiple-choice",
  211. "token-classification",
  212. "question-answering",
  213. onnx_config_cls="models.convbert.ConvBertOnnxConfig",
  214. ),
  215. "convnext": supported_features_mapping(
  216. "default",
  217. "image-classification",
  218. onnx_config_cls="models.convnext.ConvNextOnnxConfig",
  219. ),
  220. "data2vec-text": supported_features_mapping(
  221. "default",
  222. "masked-lm",
  223. "sequence-classification",
  224. "multiple-choice",
  225. "token-classification",
  226. "question-answering",
  227. onnx_config_cls="models.data2vec.Data2VecTextOnnxConfig",
  228. ),
  229. "data2vec-vision": supported_features_mapping(
  230. "default",
  231. "image-classification",
  232. # ONNX doesn't support `adaptive_avg_pool2d` yet
  233. # "semantic-segmentation",
  234. onnx_config_cls="models.data2vec.Data2VecVisionOnnxConfig",
  235. ),
  236. "deberta": supported_features_mapping(
  237. "default",
  238. "masked-lm",
  239. "sequence-classification",
  240. "token-classification",
  241. "question-answering",
  242. onnx_config_cls="models.deberta.DebertaOnnxConfig",
  243. ),
  244. "deberta-v2": supported_features_mapping(
  245. "default",
  246. "masked-lm",
  247. "sequence-classification",
  248. "multiple-choice",
  249. "token-classification",
  250. "question-answering",
  251. onnx_config_cls="models.deberta_v2.DebertaV2OnnxConfig",
  252. ),
  253. "deit": supported_features_mapping(
  254. "default", "image-classification", onnx_config_cls="models.deit.DeiTOnnxConfig"
  255. ),
  256. "detr": supported_features_mapping(
  257. "default",
  258. "object-detection",
  259. "image-segmentation",
  260. onnx_config_cls="models.detr.DetrOnnxConfig",
  261. ),
  262. "distilbert": supported_features_mapping(
  263. "default",
  264. "masked-lm",
  265. "sequence-classification",
  266. "multiple-choice",
  267. "token-classification",
  268. "question-answering",
  269. onnx_config_cls="models.distilbert.DistilBertOnnxConfig",
  270. ),
  271. "electra": supported_features_mapping(
  272. "default",
  273. "masked-lm",
  274. "causal-lm",
  275. "sequence-classification",
  276. "multiple-choice",
  277. "token-classification",
  278. "question-answering",
  279. onnx_config_cls="models.electra.ElectraOnnxConfig",
  280. ),
  281. "flaubert": supported_features_mapping(
  282. "default",
  283. "masked-lm",
  284. "causal-lm",
  285. "sequence-classification",
  286. "multiple-choice",
  287. "token-classification",
  288. "question-answering",
  289. onnx_config_cls="models.flaubert.FlaubertOnnxConfig",
  290. ),
  291. "gpt2": supported_features_mapping(
  292. "default",
  293. "default-with-past",
  294. "causal-lm",
  295. "causal-lm-with-past",
  296. "sequence-classification",
  297. "token-classification",
  298. onnx_config_cls="models.gpt2.GPT2OnnxConfig",
  299. ),
  300. "gptj": supported_features_mapping(
  301. "default",
  302. "default-with-past",
  303. "causal-lm",
  304. "causal-lm-with-past",
  305. "question-answering",
  306. "sequence-classification",
  307. onnx_config_cls="models.gptj.GPTJOnnxConfig",
  308. ),
  309. "gpt-neo": supported_features_mapping(
  310. "default",
  311. "default-with-past",
  312. "causal-lm",
  313. "causal-lm-with-past",
  314. "sequence-classification",
  315. onnx_config_cls="models.gpt_neo.GPTNeoOnnxConfig",
  316. ),
  317. "groupvit": supported_features_mapping(
  318. "default",
  319. onnx_config_cls="models.groupvit.GroupViTOnnxConfig",
  320. ),
  321. "ibert": supported_features_mapping(
  322. "default",
  323. "masked-lm",
  324. "sequence-classification",
  325. "multiple-choice",
  326. "token-classification",
  327. "question-answering",
  328. onnx_config_cls="models.ibert.IBertOnnxConfig",
  329. ),
  330. "imagegpt": supported_features_mapping(
  331. "default", "image-classification", onnx_config_cls="models.imagegpt.ImageGPTOnnxConfig"
  332. ),
  333. "layoutlm": supported_features_mapping(
  334. "default",
  335. "masked-lm",
  336. "sequence-classification",
  337. "token-classification",
  338. onnx_config_cls="models.layoutlm.LayoutLMOnnxConfig",
  339. ),
  340. "layoutlmv3": supported_features_mapping(
  341. "default",
  342. "question-answering",
  343. "sequence-classification",
  344. "token-classification",
  345. onnx_config_cls="models.layoutlmv3.LayoutLMv3OnnxConfig",
  346. ),
  347. "levit": supported_features_mapping(
  348. "default", "image-classification", onnx_config_cls="models.levit.LevitOnnxConfig"
  349. ),
  350. "longt5": supported_features_mapping(
  351. "default",
  352. "default-with-past",
  353. "seq2seq-lm",
  354. "seq2seq-lm-with-past",
  355. onnx_config_cls="models.longt5.LongT5OnnxConfig",
  356. ),
  357. "longformer": supported_features_mapping(
  358. "default",
  359. "masked-lm",
  360. "multiple-choice",
  361. "question-answering",
  362. "sequence-classification",
  363. "token-classification",
  364. onnx_config_cls="models.longformer.LongformerOnnxConfig",
  365. ),
  366. "marian": supported_features_mapping(
  367. "default",
  368. "default-with-past",
  369. "seq2seq-lm",
  370. "seq2seq-lm-with-past",
  371. "causal-lm",
  372. "causal-lm-with-past",
  373. onnx_config_cls="models.marian.MarianOnnxConfig",
  374. ),
  375. "mbart": supported_features_mapping(
  376. "default",
  377. "default-with-past",
  378. "causal-lm",
  379. "causal-lm-with-past",
  380. "seq2seq-lm",
  381. "seq2seq-lm-with-past",
  382. "sequence-classification",
  383. "question-answering",
  384. onnx_config_cls="models.mbart.MBartOnnxConfig",
  385. ),
  386. "mobilebert": supported_features_mapping(
  387. "default",
  388. "masked-lm",
  389. "sequence-classification",
  390. "multiple-choice",
  391. "token-classification",
  392. "question-answering",
  393. onnx_config_cls="models.mobilebert.MobileBertOnnxConfig",
  394. ),
  395. "mobilenet-v1": supported_features_mapping(
  396. "default",
  397. "image-classification",
  398. onnx_config_cls="models.mobilenet_v1.MobileNetV1OnnxConfig",
  399. ),
  400. "mobilenet-v2": supported_features_mapping(
  401. "default",
  402. "image-classification",
  403. onnx_config_cls="models.mobilenet_v2.MobileNetV2OnnxConfig",
  404. ),
  405. "mobilevit": supported_features_mapping(
  406. "default",
  407. "image-classification",
  408. onnx_config_cls="models.mobilevit.MobileViTOnnxConfig",
  409. ),
  410. "mt5": supported_features_mapping(
  411. "default",
  412. "default-with-past",
  413. "seq2seq-lm",
  414. "seq2seq-lm-with-past",
  415. onnx_config_cls="models.mt5.MT5OnnxConfig",
  416. ),
  417. "m2m-100": supported_features_mapping(
  418. "default",
  419. "default-with-past",
  420. "seq2seq-lm",
  421. "seq2seq-lm-with-past",
  422. onnx_config_cls="models.m2m_100.M2M100OnnxConfig",
  423. ),
  424. "owlvit": supported_features_mapping(
  425. "default",
  426. onnx_config_cls="models.owlvit.OwlViTOnnxConfig",
  427. ),
  428. "perceiver": supported_features_mapping(
  429. "image-classification",
  430. "masked-lm",
  431. "sequence-classification",
  432. onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
  433. ),
  434. "poolformer": supported_features_mapping(
  435. "default", "image-classification", onnx_config_cls="models.poolformer.PoolFormerOnnxConfig"
  436. ),
  437. "rembert": supported_features_mapping(
  438. "default",
  439. "masked-lm",
  440. "causal-lm",
  441. "sequence-classification",
  442. "multiple-choice",
  443. "token-classification",
  444. "question-answering",
  445. onnx_config_cls="models.rembert.RemBertOnnxConfig",
  446. ),
  447. "resnet": supported_features_mapping(
  448. "default",
  449. "image-classification",
  450. onnx_config_cls="models.resnet.ResNetOnnxConfig",
  451. ),
  452. "roberta": supported_features_mapping(
  453. "default",
  454. "masked-lm",
  455. "causal-lm",
  456. "sequence-classification",
  457. "multiple-choice",
  458. "token-classification",
  459. "question-answering",
  460. onnx_config_cls="models.roberta.RobertaOnnxConfig",
  461. ),
  462. "roformer": supported_features_mapping(
  463. "default",
  464. "masked-lm",
  465. "causal-lm",
  466. "sequence-classification",
  467. "token-classification",
  468. "multiple-choice",
  469. "question-answering",
  470. "token-classification",
  471. onnx_config_cls="models.roformer.RoFormerOnnxConfig",
  472. ),
  473. "segformer": supported_features_mapping(
  474. "default",
  475. "image-classification",
  476. "semantic-segmentation",
  477. onnx_config_cls="models.segformer.SegformerOnnxConfig",
  478. ),
  479. "squeezebert": supported_features_mapping(
  480. "default",
  481. "masked-lm",
  482. "sequence-classification",
  483. "multiple-choice",
  484. "token-classification",
  485. "question-answering",
  486. onnx_config_cls="models.squeezebert.SqueezeBertOnnxConfig",
  487. ),
  488. "swin": supported_features_mapping(
  489. "default", "image-classification", onnx_config_cls="models.swin.SwinOnnxConfig"
  490. ),
  491. "t5": supported_features_mapping(
  492. "default",
  493. "default-with-past",
  494. "seq2seq-lm",
  495. "seq2seq-lm-with-past",
  496. onnx_config_cls="models.t5.T5OnnxConfig",
  497. ),
  498. "vision-encoder-decoder": supported_features_mapping(
  499. "vision2seq-lm", onnx_config_cls="models.vision_encoder_decoder.VisionEncoderDecoderOnnxConfig"
  500. ),
  501. "vit": supported_features_mapping(
  502. "default", "image-classification", onnx_config_cls="models.vit.ViTOnnxConfig"
  503. ),
  504. "whisper": supported_features_mapping(
  505. "default",
  506. "default-with-past",
  507. "speech2seq-lm",
  508. "speech2seq-lm-with-past",
  509. onnx_config_cls="models.whisper.WhisperOnnxConfig",
  510. ),
  511. "xlm": supported_features_mapping(
  512. "default",
  513. "masked-lm",
  514. "causal-lm",
  515. "sequence-classification",
  516. "multiple-choice",
  517. "token-classification",
  518. "question-answering",
  519. onnx_config_cls="models.xlm.XLMOnnxConfig",
  520. ),
  521. "xlm-roberta": supported_features_mapping(
  522. "default",
  523. "masked-lm",
  524. "causal-lm",
  525. "sequence-classification",
  526. "multiple-choice",
  527. "token-classification",
  528. "question-answering",
  529. onnx_config_cls="models.xlm_roberta.XLMRobertaOnnxConfig",
  530. ),
  531. "yolos": supported_features_mapping(
  532. "default",
  533. "object-detection",
  534. onnx_config_cls="models.yolos.YolosOnnxConfig",
  535. ),
  536. }
  537. AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
  538. @staticmethod
  539. def get_supported_features_for_model_type(
  540. model_type: str, model_name: Optional[str] = None
  541. ) -> dict[str, Callable[[PretrainedConfig], OnnxConfig]]:
  542. """
  543. Tries to retrieve the feature -> OnnxConfig constructor map from the model type.
  544. Args:
  545. model_type (`str`):
  546. The model type to retrieve the supported features for.
  547. model_name (`str`, *optional*):
  548. The name attribute of the model object, only used for the exception message.
  549. Returns:
  550. The dictionary mapping each feature to a corresponding OnnxConfig constructor.
  551. """
  552. model_type = model_type.lower()
  553. if model_type not in FeaturesManager._SUPPORTED_MODEL_TYPE:
  554. model_type_and_model_name = f"{model_type} ({model_name})" if model_name else model_type
  555. raise KeyError(
  556. f"{model_type_and_model_name} is not supported yet. "
  557. f"Only {list(FeaturesManager._SUPPORTED_MODEL_TYPE.keys())} are supported. "
  558. f"If you want to support {model_type} please propose a PR or open up an issue."
  559. )
  560. return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type]
  561. @staticmethod
  562. def feature_to_task(feature: str) -> str:
  563. return feature.replace("-with-past", "")
  564. @staticmethod
  565. def _validate_framework_choice(framework: str):
  566. """
  567. Validates if the framework requested for the export is both correct and available, otherwise throws an
  568. exception.
  569. """
  570. if framework not in ["pt", "tf"]:
  571. raise ValueError(
  572. f"Only two frameworks are supported for ONNX export: pt or tf, but {framework} was provided."
  573. )
  574. elif framework == "pt" and not is_torch_available():
  575. raise RuntimeError("Cannot export model to ONNX using PyTorch because no PyTorch package was found.")
  576. elif framework == "tf" and not is_tf_available():
  577. raise RuntimeError("Cannot export model to ONNX using TensorFlow because no TensorFlow package was found.")
  578. @staticmethod
  579. def get_model_class_for_feature(feature: str, framework: str = "pt") -> type:
  580. """
  581. Attempts to retrieve an AutoModel class from a feature name.
  582. Args:
  583. feature (`str`):
  584. The feature required.
  585. framework (`str`, *optional*, defaults to `"pt"`):
  586. The framework to use for the export.
  587. Returns:
  588. The AutoModel class corresponding to the feature.
  589. """
  590. task = FeaturesManager.feature_to_task(feature)
  591. FeaturesManager._validate_framework_choice(framework)
  592. if framework == "pt":
  593. task_to_automodel = FeaturesManager._TASKS_TO_AUTOMODELS
  594. else:
  595. task_to_automodel = FeaturesManager._TASKS_TO_TF_AUTOMODELS
  596. if task not in task_to_automodel:
  597. raise KeyError(
  598. f"Unknown task: {feature}. Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
  599. )
  600. return task_to_automodel[task]
  601. @staticmethod
  602. def determine_framework(model: str, framework: Optional[str] = None) -> str:
  603. """
  604. Determines the framework to use for the export.
  605. The priority is in the following order:
  606. 1. User input via `framework`.
  607. 2. If local checkpoint is provided, use the same framework as the checkpoint.
  608. 3. Available framework in environment, with priority given to PyTorch
  609. Args:
  610. model (`str`):
  611. The name of the model to export.
  612. framework (`str`, *optional*, defaults to `None`):
  613. The framework to use for the export. See above for priority if none provided.
  614. Returns:
  615. The framework to use for the export.
  616. """
  617. if framework is not None:
  618. return framework
  619. framework_map = {"pt": "PyTorch", "tf": "TensorFlow"}
  620. exporter_map = {"pt": "torch", "tf": "tf2onnx"}
  621. if os.path.isdir(model):
  622. if os.path.isfile(os.path.join(model, WEIGHTS_NAME)):
  623. framework = "pt"
  624. elif os.path.isfile(os.path.join(model, TF2_WEIGHTS_NAME)):
  625. framework = "tf"
  626. else:
  627. raise FileNotFoundError(
  628. "Cannot determine framework from given checkpoint location."
  629. f" There should be a {WEIGHTS_NAME} for PyTorch"
  630. f" or {TF2_WEIGHTS_NAME} for TensorFlow."
  631. )
  632. logger.info(f"Local {framework_map[framework]} model found.")
  633. else:
  634. if is_torch_available():
  635. framework = "pt"
  636. elif is_tf_available():
  637. framework = "tf"
  638. else:
  639. raise OSError("Neither PyTorch nor TensorFlow found in environment. Cannot export to ONNX.")
  640. logger.info(f"Framework not requested. Using {exporter_map[framework]} to export to ONNX.")
  641. return framework
  642. @staticmethod
  643. def get_model_from_feature(
  644. feature: str, model: str, framework: Optional[str] = None, cache_dir: Optional[str] = None
  645. ) -> Union["PreTrainedModel", "TFPreTrainedModel"]:
  646. """
  647. Attempts to retrieve a model from a model's name and the feature to be enabled.
  648. Args:
  649. feature (`str`):
  650. The feature required.
  651. model (`str`):
  652. The name of the model to export.
  653. framework (`str`, *optional*, defaults to `None`):
  654. The framework to use for the export. See `FeaturesManager.determine_framework` for the priority should
  655. none be provided.
  656. Returns:
  657. The instance of the model.
  658. """
  659. framework = FeaturesManager.determine_framework(model, framework)
  660. model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
  661. try:
  662. model = model_class.from_pretrained(model, cache_dir=cache_dir)
  663. except OSError:
  664. if framework == "pt":
  665. logger.info("Loading TensorFlow model in PyTorch before exporting to ONNX.")
  666. model = model_class.from_pretrained(model, from_tf=True, cache_dir=cache_dir)
  667. else:
  668. logger.info("Loading PyTorch model in TensorFlow before exporting to ONNX.")
  669. model = model_class.from_pretrained(model, from_pt=True, cache_dir=cache_dir)
  670. return model
  671. @staticmethod
  672. def check_supported_model_or_raise(
  673. model: Union["PreTrainedModel", "TFPreTrainedModel"], feature: str = "default"
  674. ) -> tuple[str, Callable]:
  675. """
  676. Check whether or not the model has the requested features.
  677. Args:
  678. model: The model to export.
  679. feature: The name of the feature to check if it is available.
  680. Returns:
  681. (str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties.
  682. """
  683. model_type = model.config.model_type.replace("_", "-")
  684. model_name = getattr(model, "name", "")
  685. model_features = FeaturesManager.get_supported_features_for_model_type(model_type, model_name=model_name)
  686. if feature not in model_features:
  687. raise ValueError(
  688. f"{model.config.model_type} doesn't support feature {feature}. Supported values are: {model_features}"
  689. )
  690. return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]
  691. def get_config(model_type: str, feature: str) -> OnnxConfig:
  692. """
  693. Gets the OnnxConfig for a model_type and feature combination.
  694. Args:
  695. model_type (`str`):
  696. The model type to retrieve the config for.
  697. feature (`str`):
  698. The feature to retrieve the config for.
  699. Returns:
  700. `OnnxConfig`: config for the combination
  701. """
  702. return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]