| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749 |
- import os
- from functools import partial, reduce
- from typing import TYPE_CHECKING, Callable, Optional, Union
- import transformers
- from .. import PretrainedConfig, is_tf_available, is_torch_available
- from ..utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging
- from .config import OnnxConfig
- if TYPE_CHECKING:
- from transformers import PreTrainedModel, TFPreTrainedModel
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# PyTorch auto-model classes; imported only when `is_torch_available()` reports a usable torch install.
if is_torch_available():
    from transformers.models.auto import (
        AutoModel,
        AutoModelForCausalLM,
        AutoModelForImageClassification,
        AutoModelForImageSegmentation,
        AutoModelForMaskedImageModeling,
        AutoModelForMaskedLM,
        AutoModelForMultipleChoice,
        AutoModelForObjectDetection,
        AutoModelForQuestionAnswering,
        AutoModelForSemanticSegmentation,
        AutoModelForSeq2SeqLM,
        AutoModelForSequenceClassification,
        AutoModelForSpeechSeq2Seq,
        AutoModelForTokenClassification,
        AutoModelForVision2Seq,
    )

# TensorFlow auto-model classes; imported only when `is_tf_available()` reports a usable TensorFlow install.
if is_tf_available():
    from transformers.models.auto import (
        TFAutoModel,
        TFAutoModelForCausalLM,
        TFAutoModelForMaskedLM,
        TFAutoModelForMultipleChoice,
        TFAutoModelForQuestionAnswering,
        TFAutoModelForSemanticSegmentation,
        TFAutoModelForSeq2SeqLM,
        TFAutoModelForSequenceClassification,
        TFAutoModelForTokenClassification,
    )

# Warn (rather than fail) at import time when neither backend is installed.
if not (is_torch_available() or is_tf_available()):
    logger.warning(
        "The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models"
        " without one of these libraries installed."
    )
def supported_features_mapping(
    *supported_features: str, onnx_config_cls: Optional[str] = None
) -> dict[str, Callable[[PretrainedConfig], OnnxConfig]]:
    """
    Build the mapping from each supported feature to the constructor of the matching OnnxConfig for a model.

    Args:
        *supported_features: The names of the supported features. A name containing `"-with-past"` is mapped to
            the `with_past` constructor of the task obtained by removing that marker.
        onnx_config_cls: Dotted path of the OnnxConfig class, relative to the `transformers` package
            (for example `"models.bert.BertOnnxConfig"`).

    Returns:
        A dictionary mapping each feature to a `functools.partial` that builds the OnnxConfig.

    Raises:
        ValueError: If `onnx_config_cls` is not provided.
    """
    if onnx_config_cls is None:
        raise ValueError("A OnnxConfig class must be provided")

    # Walk the dotted path attribute by attribute, starting from the top-level `transformers` package.
    config_cls = reduce(getattr, onnx_config_cls.split("."), transformers)

    def _constructor_for(feature: str):
        # "<task>-with-past" selects the past-key-values variant of <task>.
        if "-with-past" in feature:
            return partial(config_cls.with_past, task=feature.replace("-with-past", ""))
        return partial(config_cls.from_model_config, task=feature)

    return {feature: _constructor_for(feature) for feature in supported_features}
class FeaturesManager:
    """
    Registry of the ONNX export features supported by each model type, with helpers to resolve the auto-model
    class, the framework and the OnnxConfig constructor that belong to a given feature.
    """

    # Task name -> auto-model class. Each table stays empty when its backend is not available.
    _TASKS_TO_AUTOMODELS = {}
    _TASKS_TO_TF_AUTOMODELS = {}
    if is_torch_available():
        _TASKS_TO_AUTOMODELS = {
            "default": AutoModel,
            "masked-lm": AutoModelForMaskedLM,
            "causal-lm": AutoModelForCausalLM,
            "seq2seq-lm": AutoModelForSeq2SeqLM,
            "sequence-classification": AutoModelForSequenceClassification,
            "token-classification": AutoModelForTokenClassification,
            "multiple-choice": AutoModelForMultipleChoice,
            "object-detection": AutoModelForObjectDetection,
            "question-answering": AutoModelForQuestionAnswering,
            "image-classification": AutoModelForImageClassification,
            "image-segmentation": AutoModelForImageSegmentation,
            "masked-im": AutoModelForMaskedImageModeling,
            "semantic-segmentation": AutoModelForSemanticSegmentation,
            "vision2seq-lm": AutoModelForVision2Seq,
            "speech2seq-lm": AutoModelForSpeechSeq2Seq,
        }
    if is_tf_available():
        _TASKS_TO_TF_AUTOMODELS = {
            "default": TFAutoModel,
            "masked-lm": TFAutoModelForMaskedLM,
            "causal-lm": TFAutoModelForCausalLM,
            "seq2seq-lm": TFAutoModelForSeq2SeqLM,
            "sequence-classification": TFAutoModelForSequenceClassification,
            "token-classification": TFAutoModelForTokenClassification,
            "multiple-choice": TFAutoModelForMultipleChoice,
            "question-answering": TFAutoModelForQuestionAnswering,
            "semantic-segmentation": TFAutoModelForSemanticSegmentation,
        }

    # Model type (dash-separated, lowercase) -> {feature name -> OnnxConfig constructor}.
    _SUPPORTED_MODEL_TYPE = {
        "albert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.albert.AlbertOnnxConfig",
        ),
        "bart": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            "sequence-classification",
            "question-answering",
            onnx_config_cls="models.bart.BartOnnxConfig",
        ),
        # BEiT cannot be used with the masked image modeling autoclass, so this feature is excluded here
        "beit": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.beit.BeitOnnxConfig"
        ),
        "bert": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.bert.BertOnnxConfig",
        ),
        "big-bird": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.big_bird.BigBirdOnnxConfig",
        ),
        "bigbird-pegasus": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            "sequence-classification",
            "question-answering",
            onnx_config_cls="models.bigbird_pegasus.BigBirdPegasusOnnxConfig",
        ),
        "blenderbot": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.blenderbot.BlenderbotOnnxConfig",
        ),
        "blenderbot-small": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig",
        ),
        "bloom": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.bloom.BloomOnnxConfig",
        ),
        "camembert": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.camembert.CamembertOnnxConfig",
        ),
        "clip": supported_features_mapping(
            "default",
            onnx_config_cls="models.clip.CLIPOnnxConfig",
        ),
        "codegen": supported_features_mapping(
            "default",
            "causal-lm",
            onnx_config_cls="models.codegen.CodeGenOnnxConfig",
        ),
        "convbert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.convbert.ConvBertOnnxConfig",
        ),
        "convnext": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.convnext.ConvNextOnnxConfig",
        ),
        "data2vec-text": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.data2vec.Data2VecTextOnnxConfig",
        ),
        "data2vec-vision": supported_features_mapping(
            "default",
            "image-classification",
            # ONNX doesn't support `adaptive_avg_pool2d` yet
            # "semantic-segmentation",
            onnx_config_cls="models.data2vec.Data2VecVisionOnnxConfig",
        ),
        "deberta": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.deberta.DebertaOnnxConfig",
        ),
        "deberta-v2": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.deberta_v2.DebertaV2OnnxConfig",
        ),
        "deit": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.deit.DeiTOnnxConfig"
        ),
        "detr": supported_features_mapping(
            "default",
            "object-detection",
            "image-segmentation",
            onnx_config_cls="models.detr.DetrOnnxConfig",
        ),
        "distilbert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.distilbert.DistilBertOnnxConfig",
        ),
        "electra": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.electra.ElectraOnnxConfig",
        ),
        "flaubert": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.flaubert.FlaubertOnnxConfig",
        ),
        "gpt2": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.gpt2.GPT2OnnxConfig",
        ),
        "gptj": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "question-answering",
            "sequence-classification",
            onnx_config_cls="models.gptj.GPTJOnnxConfig",
        ),
        "gpt-neo": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "sequence-classification",
            onnx_config_cls="models.gpt_neo.GPTNeoOnnxConfig",
        ),
        "groupvit": supported_features_mapping(
            "default",
            onnx_config_cls="models.groupvit.GroupViTOnnxConfig",
        ),
        "ibert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.ibert.IBertOnnxConfig",
        ),
        "imagegpt": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.imagegpt.ImageGPTOnnxConfig"
        ),
        "layoutlm": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.layoutlm.LayoutLMOnnxConfig",
        ),
        "layoutlmv3": supported_features_mapping(
            "default",
            "question-answering",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.layoutlmv3.LayoutLMv3OnnxConfig",
        ),
        "levit": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.levit.LevitOnnxConfig"
        ),
        "longt5": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.longt5.LongT5OnnxConfig",
        ),
        "longformer": supported_features_mapping(
            "default",
            "masked-lm",
            "multiple-choice",
            "question-answering",
            "sequence-classification",
            "token-classification",
            onnx_config_cls="models.longformer.LongformerOnnxConfig",
        ),
        "marian": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            "causal-lm",
            "causal-lm-with-past",
            onnx_config_cls="models.marian.MarianOnnxConfig",
        ),
        "mbart": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
            "causal-lm-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            "sequence-classification",
            "question-answering",
            onnx_config_cls="models.mbart.MBartOnnxConfig",
        ),
        "mobilebert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.mobilebert.MobileBertOnnxConfig",
        ),
        "mobilenet-v1": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.mobilenet_v1.MobileNetV1OnnxConfig",
        ),
        "mobilenet-v2": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.mobilenet_v2.MobileNetV2OnnxConfig",
        ),
        "mobilevit": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.mobilevit.MobileViTOnnxConfig",
        ),
        "mt5": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.mt5.MT5OnnxConfig",
        ),
        "m2m-100": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.m2m_100.M2M100OnnxConfig",
        ),
        "owlvit": supported_features_mapping(
            "default",
            onnx_config_cls="models.owlvit.OwlViTOnnxConfig",
        ),
        "perceiver": supported_features_mapping(
            "image-classification",
            "masked-lm",
            "sequence-classification",
            onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
        ),
        "poolformer": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.poolformer.PoolFormerOnnxConfig"
        ),
        "rembert": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.rembert.RemBertOnnxConfig",
        ),
        "resnet": supported_features_mapping(
            "default",
            "image-classification",
            onnx_config_cls="models.resnet.ResNetOnnxConfig",
        ),
        "roberta": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.roberta.RobertaOnnxConfig",
        ),
        # "token-classification" used to be listed twice here; the duplicate has been removed.
        "roformer": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "token-classification",
            "multiple-choice",
            "question-answering",
            onnx_config_cls="models.roformer.RoFormerOnnxConfig",
        ),
        "segformer": supported_features_mapping(
            "default",
            "image-classification",
            "semantic-segmentation",
            onnx_config_cls="models.segformer.SegformerOnnxConfig",
        ),
        "squeezebert": supported_features_mapping(
            "default",
            "masked-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.squeezebert.SqueezeBertOnnxConfig",
        ),
        "swin": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.swin.SwinOnnxConfig"
        ),
        "t5": supported_features_mapping(
            "default",
            "default-with-past",
            "seq2seq-lm",
            "seq2seq-lm-with-past",
            onnx_config_cls="models.t5.T5OnnxConfig",
        ),
        "vision-encoder-decoder": supported_features_mapping(
            "vision2seq-lm", onnx_config_cls="models.vision_encoder_decoder.VisionEncoderDecoderOnnxConfig"
        ),
        "vit": supported_features_mapping(
            "default", "image-classification", onnx_config_cls="models.vit.ViTOnnxConfig"
        ),
        "whisper": supported_features_mapping(
            "default",
            "default-with-past",
            "speech2seq-lm",
            "speech2seq-lm-with-past",
            onnx_config_cls="models.whisper.WhisperOnnxConfig",
        ),
        "xlm": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.xlm.XLMOnnxConfig",
        ),
        "xlm-roberta": supported_features_mapping(
            "default",
            "masked-lm",
            "causal-lm",
            "sequence-classification",
            "multiple-choice",
            "token-classification",
            "question-answering",
            onnx_config_cls="models.xlm_roberta.XLMRobertaOnnxConfig",
        ),
        "yolos": supported_features_mapping(
            "default",
            "object-detection",
            onnx_config_cls="models.yolos.YolosOnnxConfig",
        ),
    }

    # Sorted union of every feature name supported by at least one model type.
    AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))

    @staticmethod
    def get_supported_features_for_model_type(
        model_type: str, model_name: Optional[str] = None
    ) -> dict[str, Callable[[PretrainedConfig], OnnxConfig]]:
        """
        Tries to retrieve the feature -> OnnxConfig constructor map from the model type.

        Args:
            model_type (`str`):
                The model type to retrieve the supported features for (matched case-insensitively).
            model_name (`str`, *optional*):
                The name attribute of the model object, only used for the exception message.

        Returns:
            The dictionary mapping each feature to a corresponding OnnxConfig constructor.

        Raises:
            KeyError: If the model type is not registered in `_SUPPORTED_MODEL_TYPE`.
        """
        model_type = model_type.lower()
        if model_type not in FeaturesManager._SUPPORTED_MODEL_TYPE:
            model_type_and_model_name = f"{model_type} ({model_name})" if model_name else model_type
            raise KeyError(
                f"{model_type_and_model_name} is not supported yet. "
                f"Only {list(FeaturesManager._SUPPORTED_MODEL_TYPE.keys())} are supported. "
                f"If you want to support {model_type} please propose a PR or open up an issue."
            )
        return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type]

    @staticmethod
    def feature_to_task(feature: str) -> str:
        """Strip the `-with-past` marker from a feature name to get the underlying task name."""
        return feature.replace("-with-past", "")

    @staticmethod
    def _validate_framework_choice(framework: str):
        """
        Validates if the framework requested for the export is both correct and available, otherwise throws an
        exception.

        Raises:
            ValueError: If `framework` is neither `"pt"` nor `"tf"`.
            RuntimeError: If the requested framework's package is not installed.
        """
        if framework not in ["pt", "tf"]:
            raise ValueError(
                f"Only two frameworks are supported for ONNX export: pt or tf, but {framework} was provided."
            )
        elif framework == "pt" and not is_torch_available():
            raise RuntimeError("Cannot export model to ONNX using PyTorch because no PyTorch package was found.")
        elif framework == "tf" and not is_tf_available():
            raise RuntimeError("Cannot export model to ONNX using TensorFlow because no TensorFlow package was found.")

    @staticmethod
    def get_model_class_for_feature(feature: str, framework: str = "pt") -> type:
        """
        Attempts to retrieve an AutoModel class from a feature name.

        Args:
            feature (`str`):
                The feature required.
            framework (`str`, *optional*, defaults to `"pt"`):
                The framework to use for the export.

        Returns:
            The AutoModel class corresponding to the feature.

        Raises:
            KeyError: If the feature's task has no auto-model class for the selected framework.
        """
        task = FeaturesManager.feature_to_task(feature)
        FeaturesManager._validate_framework_choice(framework)
        if framework == "pt":
            task_to_automodel = FeaturesManager._TASKS_TO_AUTOMODELS
        else:
            task_to_automodel = FeaturesManager._TASKS_TO_TF_AUTOMODELS
        if task not in task_to_automodel:
            # List the task names valid for *this* framework (previously the PyTorch classes were listed).
            raise KeyError(f"Unknown task: {feature}. Possible values are {list(task_to_automodel.keys())}")
        return task_to_automodel[task]

    @staticmethod
    def determine_framework(model: str, framework: Optional[str] = None) -> str:
        """
        Determines the framework to use for the export.

        The priority is in the following order:
            1. User input via `framework`.
            2. If local checkpoint is provided, use the same framework as the checkpoint.
            3. Available framework in environment, with priority given to PyTorch

        Args:
            model (`str`):
                The name of the model to export.
            framework (`str`, *optional*, defaults to `None`):
                The framework to use for the export. See above for priority if none provided.

        Returns:
            The framework to use for the export.

        Raises:
            FileNotFoundError: If `model` is a directory holding neither PyTorch nor TensorFlow weights.
            OSError: If no framework is requested and neither PyTorch nor TensorFlow is installed.
        """
        if framework is not None:
            return framework

        framework_map = {"pt": "PyTorch", "tf": "TensorFlow"}
        exporter_map = {"pt": "torch", "tf": "tf2onnx"}

        if os.path.isdir(model):
            # Local checkpoint: the weights file that is present decides the framework.
            if os.path.isfile(os.path.join(model, WEIGHTS_NAME)):
                framework = "pt"
            elif os.path.isfile(os.path.join(model, TF2_WEIGHTS_NAME)):
                framework = "tf"
            else:
                raise FileNotFoundError(
                    "Cannot determine framework from given checkpoint location."
                    f" There should be a {WEIGHTS_NAME} for PyTorch"
                    f" or {TF2_WEIGHTS_NAME} for TensorFlow."
                )
            logger.info(f"Local {framework_map[framework]} model found.")
        else:
            if is_torch_available():
                framework = "pt"
            elif is_tf_available():
                framework = "tf"
            else:
                raise OSError("Neither PyTorch nor TensorFlow found in environment. Cannot export to ONNX.")

        logger.info(f"Framework not requested. Using {exporter_map[framework]} to export to ONNX.")

        return framework

    @staticmethod
    def get_model_from_feature(
        feature: str, model: str, framework: Optional[str] = None, cache_dir: Optional[str] = None
    ) -> Union["PreTrainedModel", "TFPreTrainedModel"]:
        """
        Attempts to retrieve a model from a model's name and the feature to be enabled.

        Args:
            feature (`str`):
                The feature required.
            model (`str`):
                The name of the model to export.
            framework (`str`, *optional*, defaults to `None`):
                The framework to use for the export. See `FeaturesManager.determine_framework` for the priority should
                none be provided.
            cache_dir (`str`, *optional*, defaults to `None`):
                Forwarded unchanged as `cache_dir` to `from_pretrained`.

        Returns:
            The instance of the model.
        """
        framework = FeaturesManager.determine_framework(model, framework)
        model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
        try:
            model = model_class.from_pretrained(model, cache_dir=cache_dir)
        except OSError:
            # The checkpoint may only hold weights for the other framework; retry with cross-framework loading.
            if framework == "pt":
                logger.info("Loading TensorFlow model in PyTorch before exporting to ONNX.")
                model = model_class.from_pretrained(model, from_tf=True, cache_dir=cache_dir)
            else:
                logger.info("Loading PyTorch model in TensorFlow before exporting to ONNX.")
                model = model_class.from_pretrained(model, from_pt=True, cache_dir=cache_dir)
        return model

    @staticmethod
    def check_supported_model_or_raise(
        model: Union["PreTrainedModel", "TFPreTrainedModel"], feature: str = "default"
    ) -> tuple[str, Callable]:
        """
        Check whether or not the model has the requested features.

        Args:
            model: The model to export.
            feature: The name of the feature to check if it is available.

        Returns:
            A tuple `(model_type, onnx_config_constructor)`: the model's `config.model_type` and the callable that
            builds the OnnxConfig for the requested feature.

        Raises:
            KeyError: If the model type is not supported.
            ValueError: If the model type does not support `feature`.
        """
        # Registry keys use dashes, whereas config model types may use underscores.
        model_type = model.config.model_type.replace("_", "-")
        model_name = getattr(model, "name", "")
        model_features = FeaturesManager.get_supported_features_for_model_type(model_type, model_name=model_name)
        if feature not in model_features:
            raise ValueError(
                f"{model.config.model_type} doesn't support feature {feature}. "
                f"Supported values are: {list(model_features.keys())}"
            )
        # Reuse the mapping already resolved (with case normalization) instead of re-indexing the registry.
        return model.config.model_type, model_features[feature]

    @staticmethod
    def get_config(model_type: str, feature: str) -> Callable[[PretrainedConfig], OnnxConfig]:
        """
        Gets the OnnxConfig constructor for a model_type and feature combination.

        Args:
            model_type (`str`):
                The model type to retrieve the config for (matched case-insensitively).
            feature (`str`):
                The feature to retrieve the config for.

        Returns:
            The callable that builds the `OnnxConfig` for the combination from a `PretrainedConfig`.

        Raises:
            KeyError: If the model type or the feature is not supported.
        """
        return FeaturesManager.get_supported_features_for_model_type(model_type)[feature]
|