| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274 |
- # coding=utf-8
- # Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import gc
- import json
- import os
- import warnings
- from functools import partial
- from pickle import UnpicklingError
- from typing import Any, Optional, Union
- import flax.linen as nn
- import jax
- import jax.numpy as jnp
- import msgpack.exceptions
- from flax.core.frozen_dict import FrozenDict, unfreeze
- from flax.serialization import from_bytes, to_bytes
- from flax.traverse_util import flatten_dict, unflatten_dict
- from jax.random import PRNGKey
- from .configuration_utils import PretrainedConfig
- from .dynamic_module_utils import custom_object_save
- from .generation import FlaxGenerationMixin, GenerationConfig
- from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
- from .utils import (
- FLAX_WEIGHTS_INDEX_NAME,
- FLAX_WEIGHTS_NAME,
- SAFE_WEIGHTS_INDEX_NAME,
- SAFE_WEIGHTS_NAME,
- WEIGHTS_INDEX_NAME,
- WEIGHTS_NAME,
- PushToHubMixin,
- add_code_sample_docstrings,
- add_start_docstrings_to_model_forward,
- cached_file,
- copy_func,
- download_url,
- has_file,
- is_offline_mode,
- is_remote_url,
- logging,
- replace_return_docstrings,
- )
- from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
- from .utils.import_utils import is_safetensors_available
- if is_safetensors_available():
- from safetensors import safe_open
- from safetensors.flax import load_file as safe_load_file
- from safetensors.flax import save_file as safe_save_file
# Module-level logger, obtained from the library's `logging` utilities imported from `.utils`.
logger = logging.get_logger(__name__)
def quick_gelu(x):
    """
    Sigmoid-gated activation computing ``x * sigmoid(1.702 * x)`` element-wise.

    Args:
        x: Input value or array accepted by `jax.nn.sigmoid`.

    Returns:
        `x` scaled by the sigmoid of `1.702 * x`.
    """
    gate = jax.nn.sigmoid(1.702 * x)
    return x * gate
# Lookup table from activation-function name strings to Flax activation callables.
# The two tanh-approximated GELU aliases ("gelu_new", "gelu_pytorch_tanh") map to the same computation,
# and "silu"/"swish" both map to `nn.swish`.
ACT2FN = dict(
    gelu=partial(nn.gelu, approximate=False),
    relu=nn.relu,
    silu=nn.swish,
    swish=nn.swish,
    gelu_new=partial(nn.gelu, approximate=True),
    quick_gelu=quick_gelu,
    gelu_pytorch_tanh=partial(nn.gelu, approximate=True),
    tanh=nn.tanh,
)
def flax_shard_checkpoint(params, max_shard_size="10GB"):
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so
    there is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For
    example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as
    [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
    have a size greater than `max_shard_size`.

    </Tip>

    Args:
        params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).

    Returns:
        `tuple[dict, Optional[dict]]`: A mapping from shard file name to the flat dict (`"/"`-joined keys) of weights
        stored in that shard, and an index dict with `"metadata"` (`{"total_size": <bytes>}`) and `"weight_map"`
        (weight name -> shard file name). When all weights fit in a single shard, the mapping has the single key
        `FLAX_WEIGHTS_NAME` and the index is `None`.
    """
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = []
    current_block = {}
    current_block_size = 0
    total_size = 0

    # flatten the weights to chunk
    weights = flatten_dict(params, sep="/")
    for name, weight in weights.items():
        weight_size = weight.size * weight.dtype.itemsize

        # If this weight is going to tip the current shard over the maximal size, we split. An empty block is
        # never flushed: otherwise an oversized *first* weight would produce an empty leading shard.
        if current_block and current_block_size + weight_size > max_shard_size:
            sharded_state_dicts.append(current_block)
            current_block = {}
            current_block_size = 0

        current_block[name] = weight
        current_block_size += weight_size
        total_size += weight_size

    # Add the last block
    sharded_state_dicts.append(current_block)

    # If we only have one shard, we return it
    if len(sharded_state_dicts) == 1:
        return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index
    weight_map = {}
    shards = {}
    num_shards = len(sharded_state_dicts)
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx + 1:05d}-of-{num_shards:05d}.msgpack")
        shards[shard_file] = shard
        for weight_name in shard:
            weight_map[weight_name] = shard_file

    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
    r"""
    Base class for all models.

    [`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
    downloading and saving models.

    Class attributes (overridden by derived classes):

        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
          for this model architecture.
        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
          classes of the same architecture adding modules on top of the base model.
        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
          models, `pixel_values` for vision models and `input_values` for speech models).
    """

    # Overridden by derived classes (see the class docstring).
    config_class = None
    base_model_prefix = ""
    main_input_name = "input_ids"
    # NOTE(review): not read in the visible code; presumably records an associated Auto class — confirm.
    _auto_class = None
    # NOTE: a single mutable set stored on the class; every subclass that does not override it shares this object.
    _missing_keys = set()
def __init__(
    self,
    config: PretrainedConfig,
    module: nn.Module,
    input_shape: tuple = (1, 1),
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
    _do_init: bool = True,
):
    """
    Wrap a Flax `module` together with its `config` and set up parameter bookkeeping.

    Args:
        config (`PretrainedConfig`): The model configuration. Must not be `None`.
        module (`nn.Module`): The Flax module implementing the model. Must not be `None`.
        input_shape (`tuple`, *optional*, defaults to `(1, 1)`):
            Shape handed to `init_weights`, both for real initialization and for shape inference.
        seed (`int`, *optional*, defaults to 0): Seed used to build `self.key` with `PRNGKey`.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): Stored as `self.dtype`.
        _do_init (`bool`, *optional*, defaults to `True`):
            If `True`, `init_weights` is run and its result is assigned to `self.params`. If `False`, only the
            parameter shapes are computed (via `jax.eval_shape`) and `self.params` is never assigned.

    Raises:
        ValueError: If `config` or `module` is `None`.
    """
    logger.warning_once(
        "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We "
        "recommend migrating to PyTorch classes or pinning your version of Transformers."
    )
    if config is None:
        raise ValueError("config cannot be None")

    if module is None:
        raise ValueError("module cannot be None")

    # Those are private to be exposed as typed properties on derived classes.
    self._config = config
    self._module = module

    # Those are public as their type is generic to every derived class.
    self.key = PRNGKey(seed)
    self.dtype = dtype
    self.input_shape = input_shape
    # Only models that support `.generate()` get a generation config derived from `config`.
    self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None

    # To check if the model was initialized automatically.
    self._is_initialized = _do_init

    if _do_init:
        # randomly initialized parameters
        random_params = self.init_weights(self.key, input_shape)
        # eval_shape over an identity function yields only the shape/dtype structure of the real params.
        params_shape_tree = jax.eval_shape(lambda params: params, random_params)
    else:
        # Abstractly evaluate init_weights so the shapes are known without materializing the parameters.
        init_fn = partial(self.init_weights, input_shape=input_shape)
        params_shape_tree = jax.eval_shape(init_fn, self.key)

        logger.info(
            "Model weights are not initialized as `_do_init` is set to `False`. "
            f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
        )

    # get the shape of the parameters
    self._params_shape_tree = params_shape_tree

    # save required_params as set: the flattened parameter paths the `params` setter insists on
    self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())

    # initialize the parameters — must come last, since the `params` setter reads
    # `self._is_initialized` and `self.required_params`, both assigned above.
    if _do_init:
        self.params = random_params
def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> dict:
    """
    Create the model parameters. Derived classes must override this hook.

    Args:
        rng: PRNG key used for initialization.
        input_shape: Shape of the dummy input used to trace the module.
        params: Optional existing parameters; presumably used by subclasses to fill in missing entries — confirm
            in the concrete implementations.

    Raises:
        NotImplementedError: Always, in this base class.
    """
    message = f"init method has to be implemented for {self}"
    raise NotImplementedError(message)
def enable_gradient_checkpointing(self):
    """
    Switch the model to gradient checkpointing. Derived classes that support it must override this.

    Raises:
        NotImplementedError: Always, in this base class.
    """
    reason = f"gradient checkpointing method has to be implemented for {self}"
    raise NotImplementedError(reason)
- @classmethod
- def _from_config(cls, config, **kwargs):
- """
- All context managers that the model should be initialized under go here.
- """
- return cls(config, **kwargs)
@property
def framework(self) -> str:
    """
    :str: The framework tag of this model class, always `"flax"`.
    """
    name = "flax"
    return name
@property
def config(self) -> PretrainedConfig:
    """The configuration object this model was built with (stored as `self._config`)."""
    cfg = self._config
    return cfg
@property
def module(self) -> nn.Module:
    """The Flax module wrapped by this model (stored as `self._module`)."""
    wrapped = self._module
    return wrapped
@property
def params(self) -> Union[dict, FrozenDict]:
    """
    The parameters held by the model.

    Raises:
        ValueError: If the model was created with `_do_init=False`; the parameters then live outside the model.
    """
    if self._is_initialized:
        return self._params
    raise ValueError(
        "`params` cannot be accessed from model when the model is created with `_do_init=False`. "
        "You must call `init_weights` manually and store the params outside of the model and "
        "pass it explicitly where needed."
    )
@property
def required_params(self) -> set:
    """The set of flattened parameter paths that any assigned `params` tree must contain."""
    required = self._required_params
    return required
@property
def params_shape_tree(self) -> dict:
    """The shape/dtype structure of the parameters, as computed with `jax.eval_shape` at construction."""
    shape_tree = self._params_shape_tree
    return shape_tree
@params.setter
def params(self, params: Union[dict, FrozenDict]):
    """
    Store a new parameter tree after checking that it contains every required parameter path.

    A `FrozenDict` is unfrozen before it is stored.

    Raises:
        ValueError: If the model was created with `_do_init=False`, or if any path from `required_params` is
            absent from `params`.
    """
    # Refuse to store params on a model whose weights are managed externally.
    if not self._is_initialized:
        raise ValueError(
            "`params` cannot be set from model when the model is created with `_do_init=False`. "
            "You store the params outside of the model."
        )

    params = unfreeze(params) if isinstance(params, FrozenDict) else params

    provided = set(flatten_dict(params))
    missing = self.required_params - provided
    if missing:
        raise ValueError(
            "Some parameters are missing. Make sure that `params` include the following "
            f"parameters {missing}"
        )
    self._params = params
def _cast_floating_to(self, params: Union[dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
    """
    Return a copy of the `params` `PyTree` whose floating-point array leaves are cast to `dtype`.

    Only leaves that are `jnp.ndarray` instances with a floating dtype are converted; other leaves pass through
    unchanged. Without a `mask` the tree is mapped with `jax.tree_util.tree_map`. With a `mask`, the mask's leaves
    are paired position-by-position with the sorted flattened parameter paths, only paths paired with a truthy
    mask leaf are cast, and the result is rebuilt with `unflatten_dict`.
    """

    # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
    def _cast_if_floating(leaf):
        if not (isinstance(leaf, jnp.ndarray) and jnp.issubdtype(leaf.dtype, jnp.floating)):
            return leaf
        return leaf.astype(dtype)

    if mask is None:
        return jax.tree_util.tree_map(_cast_if_floating, params)

    flat_params = flatten_dict(params)
    mask_leaves = jax.tree_util.tree_leaves(mask)
    # Positional pairing: assumes `mask` mirrors the structure of `params`.
    for path, selected in zip(sorted(flat_params), mask_leaves):
        if selected:
            flat_params[path] = _cast_if_floating(flat_params[path])

    return unflatten_dict(flat_params)
def to_bf16(self, params: Union[dict, FrozenDict], mask: Any = None):
    r"""
    Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
    the `params` in place.

    This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
    half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`, *optional*):
            A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
            you want to cast, and should be `False` for those you want to skip. If `None`, every floating-point
            parameter is cast.

    Returns:
        A new `PyTree` with the selected floating-point leaves cast to `jnp.bfloat16`.

    Examples:

    ```python
    >>> from transformers import FlaxBertModel

    >>> # load model
    >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
    >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
    >>> model.params = model.to_bf16(model.params)
    >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
    >>> # then pass the mask as follows
    >>> from flax import traverse_util

    >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
    >>> flat_params = traverse_util.flatten_dict(model.params)
    >>> mask = {
    ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
    ...     for path in flat_params
    ... }
    >>> mask = traverse_util.unflatten_dict(mask)
    >>> model.params = model.to_bf16(model.params, mask)
    ```"""
    return self._cast_floating_to(params, jnp.bfloat16, mask)
def to_fp32(self, params: Union[dict, FrozenDict], mask: Any = None):
    r"""
    Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
    model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`, *optional*):
            A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
            you want to cast, and should be `False` for those you want to skip. If `None`, every floating-point
            parameter is cast.

    Returns:
        A new `PyTree` with the selected floating-point leaves cast to `jnp.float32`.

    Examples:

    ```python
    >>> from transformers import FlaxBertModel

    >>> # Download model and configuration from huggingface.co
    >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
    >>> # By default, the model params will be in fp32, to illustrate the use of this method,
    >>> # we'll first cast to fp16 and back to fp32
    >>> model.params = model.to_fp16(model.params)
    >>> # now cast back to fp32
    >>> model.params = model.to_fp32(model.params)
    ```"""
    return self._cast_floating_to(params, jnp.float32, mask)
def to_fp16(self, params: Union[dict, FrozenDict], mask: Any = None):
    r"""
    Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
    `params` in place.

    This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
    half-precision training or to save weights in float16 for inference in order to save memory and improve speed.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`, *optional*):
            A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
            you want to cast, and should be `False` for those you want to skip. If `None`, every floating-point
            parameter is cast.

    Returns:
        A new `PyTree` with the selected floating-point leaves cast to `jnp.float16`.

    Examples:

    ```python
    >>> from transformers import FlaxBertModel

    >>> # load model
    >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
    >>> # By default, the model params will be in fp32, to cast these to float16
    >>> model.params = model.to_fp16(model.params)
    >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
    >>> # then pass the mask as follows
    >>> from flax import traverse_util

    >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
    >>> flat_params = traverse_util.flatten_dict(model.params)
    >>> mask = {
    ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
    ...     for path in flat_params
    ... }
    >>> mask = traverse_util.unflatten_dict(mask)
    >>> model.params = model.to_fp16(model.params, mask)
    ```"""
    return self._cast_floating_to(params, jnp.float16, mask)
@classmethod
def load_flax_weights(cls, resolved_archive_file):
    """
    Load the weights stored in a single (non-sharded) checkpoint file.

    Args:
        resolved_archive_file (`str`):
            Local path of the checkpoint. A path ending in `.safetensors` is read with `safetensors`; any other file
            is read as bytes and deserialized with `flax.serialization.from_bytes`.

    Returns:
        The deserialized state. For safetensors, the flat `"."`-separated keys are unflattened into a nested dict.

    Raises:
        OSError: If msgpack deserialization fails. When the file's text starts with `"version"` (a git-lfs pointer
            file rather than the actual weights), the message explains how to fetch the real file.
    """
    try:
        if resolved_archive_file.endswith(".safetensors"):
            state = safe_load_file(resolved_archive_file)
            state = unflatten_dict(state, sep=".")
        else:
            with open(resolved_archive_file, "rb") as state_f:
                state = from_bytes(cls, state_f.read())
    except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
        unreadable = f"Unable to convert {resolved_archive_file} to Flax deserializable object. "
        try:
            # A git-lfs pointer is recognised by its leading "version" marker, so only a short prefix is
            # needed; avoid decoding a potentially multi-GB checkpoint as text.
            with open(resolved_archive_file) as f:
                is_lfs_pointer = f.read(len("version")) == "version"
        except UnicodeDecodeError as decode_error:
            raise OSError(unreadable) from decode_error

        if is_lfs_pointer:
            raise OSError(
                "You seem to have cloned a repository without having git-lfs installed. Please"
                " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                " folder you cloned."
            ) from e
        raise OSError(unreadable) from e

    return state
@classmethod
def load_flax_sharded_weights(cls, shard_files):
    """
    This is the same as [`flax.serialization.from_bytes`](https://flax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes)
    but for a sharded checkpoint.

    Shards are read one at a time; after each shard is merged into the result, its temporary objects are released
    (`del` + `gc.collect()`) before the next shard is read.

    Args:
        shard_files (`list[str]`):
            The list of shard files to load.

    Returns:
        `Dict`: A nested dictionary of the model parameters, rebuilt from the `"/"`-separated keys of all shards.

    Raises:
        OSError: If any shard cannot be deserialized, including when it is a git-lfs pointer file instead of the
            actual weights.
    """
    state_sharded_dict = {}

    for shard_file in shard_files:
        unreadable = f"Unable to convert {shard_file} to Flax deserializable object. "
        # load using msgpack utils
        try:
            with open(shard_file, "rb") as state_f:
                state = from_bytes(cls, state_f.read())
        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
            # The diagnosis needs its own `try`: a sibling `except` clause never catches exceptions raised inside
            # another handler, so a binary shard would otherwise escape as a raw UnicodeDecodeError.
            try:
                # Only the leading "version" marker of a git-lfs pointer is needed; avoid decoding the whole shard.
                with open(shard_file) as f:
                    is_lfs_pointer = f.read(len("version")) == "version"
            except UnicodeDecodeError as decode_error:
                raise OSError(unreadable) from decode_error

            if is_lfs_pointer:
                raise OSError(
                    "You seem to have cloned a repository without having git-lfs installed. Please"
                    " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                    " folder you cloned."
                ) from e
            raise OSError(unreadable) from e
        except (UnicodeDecodeError, ValueError) as e:
            # Other decoding failures raised by `from_bytes` itself.
            raise OSError(unreadable) from e

        state = flatten_dict(state, sep="/")
        state_sharded_dict.update(state)
        del state
        gc.collect()

    # the state dict is unflattened to match the format of model.params
    return unflatten_dict(state_sharded_dict, sep="/")
@classmethod
def can_generate(cls) -> bool:
    """
    Report whether this model class supports sequence generation with `.generate()`.

    A class counts as generation-capable unless both `prepare_inputs_for_generation` and `generate` are still
    the generic versions whose string representation mentions `GenerationMixin`.

    Returns:
        `bool`: `True` if at least one of the two hooks has been overridden, `False` otherwise.
    """
    # An overridden `prepare_inputs_for_generation` is sufficient; `generate` is only inspected otherwise.
    if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
        return True
    return "GenerationMixin" not in str(cls.generate)
- @classmethod
- def from_pretrained(
- cls,
- pretrained_model_name_or_path: Union[str, os.PathLike],
- dtype: jnp.dtype = jnp.float32,
- *model_args,
- config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
- cache_dir: Optional[Union[str, os.PathLike]] = None,
- ignore_mismatched_sizes: bool = False,
- force_download: bool = False,
- local_files_only: bool = False,
- token: Optional[Union[str, bool]] = None,
- revision: str = "main",
- **kwargs,
- ):
- r"""
- Instantiate a pretrained flax model from a pre-trained model configuration.
- The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
- pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
- task.
- The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
- weights are discarded.
- Parameters:
- pretrained_model_name_or_path (`str` or `os.PathLike`):
- Can be either:
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
- - A path to a *directory* containing model weights saved using
- [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
- - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case,
- `from_pt` should be set to `True`.
- dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
- The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
- `jax.numpy.bfloat16` (on TPUs).
- This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
- specified all the computation will be performed with the given `dtype`.
- **Note that this only specifies the dtype of the computation and does not influence the dtype of model
- parameters.**
- If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
- [`~FlaxPreTrainedModel.to_bf16`].
- model_args (sequence of positional arguments, *optional*):
- All remaining positional arguments will be passed to the underlying model's `__init__` method.
- config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
- Can be either:
- - an instance of a class derived from [`PretrainedConfig`],
- - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
- Configuration for the model to use instead of an automatically loaded configuration. Configuration can
- be automatically loaded when:
- - The model is a model provided by the library (loaded with the *model id* string of a pretrained
- model).
- - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
- save directory.
- - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
- configuration JSON file named *config.json* is found in the directory.
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
- standard cache should not be used.
- from_pt (`bool`, *optional*, defaults to `False`):
- Load the model weights from a PyTorch checkpoint save file (see docstring of
- `pretrained_model_name_or_path` argument).
- ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
- Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
- as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
- checkpoint with 3 labels).
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
- resume_download:
- Deprecated and ignored. All downloads are now resumed by default when possible.
- Will be removed in v5 of Transformers.
- proxies (`dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only(`bool`, *optional*, defaults to `False`):
- Whether or not to only look at local files (i.e., do not try to download the model).
- token (`str` or `bool`, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
- the token generated when running `hf auth login` (stored in `~/.huggingface`).
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
- identifier allowed by git.
- <Tip>
- To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
- </Tip>
- subfolder (`str`, *optional*, defaults to `""`):
- In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
- specify the folder name here.
- kwargs (remaining dictionary of keyword arguments, *optional*):
- Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
- `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
- automatically loaded:
- - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
- underlying model's `__init__` method (we assume all relevant updates to the configuration have
- already been done)
- - If a configuration is not provided, `kwargs` will be first passed to the configuration class
- initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
- corresponds to a configuration attribute will be used to override said attribute with the
- supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
- will be passed to the underlying model's `__init__` function.
- Examples:
- ```python
- >>> from transformers import BertConfig, FlaxBertModel
- >>> # Download model and configuration from huggingface.co and cache.
- >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
- >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
- >>> model = FlaxBertModel.from_pretrained("./test/saved_model/")
- >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
- >>> config = BertConfig.from_json_file("./pt_model/config.json")
- >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config)
- ```"""
- from_pt = kwargs.pop("from_pt", False)
- resume_download = kwargs.pop("resume_download", None)
- proxies = kwargs.pop("proxies", None)
- use_auth_token = kwargs.pop("use_auth_token", None)
- trust_remote_code = kwargs.pop("trust_remote_code", None)
- from_pipeline = kwargs.pop("_from_pipeline", None)
- from_auto_class = kwargs.pop("_from_auto", False)
- _do_init = kwargs.pop("_do_init", True)
- subfolder = kwargs.pop("subfolder", "")
- commit_hash = kwargs.pop("_commit_hash", None)
- # Not relevant for Flax Models
- _ = kwargs.pop("adapter_kwargs", None)
- if use_auth_token is not None:
- warnings.warn(
- "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
- FutureWarning,
- )
- if token is not None:
- raise ValueError(
- "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
- )
- token = use_auth_token
- if trust_remote_code is True:
- logger.warning(
- "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
- " ignored."
- )
- user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
- if from_pipeline is not None:
- user_agent["using_pipeline"] = from_pipeline
- if is_offline_mode() and not local_files_only:
- logger.info("Offline mode: forcing local_files_only=True")
- local_files_only = True
- # Load config if we don't provide a configuration
- if not isinstance(config, PretrainedConfig):
- config_path = config if config is not None else pretrained_model_name_or_path
- config, model_kwargs = cls.config_class.from_pretrained(
- config_path,
- cache_dir=cache_dir,
- return_unused_kwargs=True,
- force_download=force_download,
- resume_download=resume_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder=subfolder,
- _from_auto=from_auto_class,
- _from_pipeline=from_pipeline,
- _commit_hash=commit_hash,
- **kwargs,
- )
- else:
- model_kwargs = kwargs.copy()
- if commit_hash is None:
- commit_hash = getattr(config, "_commit_hash", None)
- # Add the dtype to model_kwargs
- model_kwargs["dtype"] = dtype
- # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
- # index of the files.
- is_sharded = False
- # Load model
- if pretrained_model_name_or_path is not None:
- pretrained_model_name_or_path = str(pretrained_model_name_or_path)
- is_local = os.path.isdir(pretrained_model_name_or_path)
- if is_local:
- if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
- # Load from a Flax checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
- elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
- # Load from a sharded Flax checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
- is_sharded = True
- elif is_safetensors_available() and os.path.isfile(
- os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
- ):
- # Load from a safetensors checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
- elif is_safetensors_available() and os.path.isfile(
- os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
- ):
- # Load from a safetensors checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
- elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
- # Load from a PyTorch checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
- elif from_pt and os.path.isfile(
- os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
- ):
- # Load from a sharded pytorch checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
- is_sharded = True
- # At this stage we don't have a weight file so we will raise an error.
- elif is_safetensors_available() and os.path.isfile(
- os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
- ):
- # Load from a sharded safetensors checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
- is_sharded = True
- raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
- elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
- raise OSError(
- f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
- "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
- "weights."
- )
- else:
- raise OSError(
- f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
- f"{pretrained_model_name_or_path}."
- )
- elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
- archive_file = pretrained_model_name_or_path
- is_local = True
- elif is_remote_url(pretrained_model_name_or_path):
- filename = pretrained_model_name_or_path
- resolved_archive_file = download_url(pretrained_model_name_or_path)
- else:
- if from_pt:
- filename = WEIGHTS_NAME
- else:
- filename = FLAX_WEIGHTS_NAME
- try:
- # Load from URL or cache if already cached
- cached_file_kwargs = {
- "cache_dir": cache_dir,
- "force_download": force_download,
- "proxies": proxies,
- "resume_download": resume_download,
- "local_files_only": local_files_only,
- "token": token,
- "user_agent": user_agent,
- "revision": revision,
- "subfolder": subfolder,
- "_raise_exceptions_for_gated_repo": False,
- "_raise_exceptions_for_missing_entries": False,
- "_commit_hash": commit_hash,
- }
- resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
- # Maybe the checkpoint is sharded, we try to grab the index name in this case.
- if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
- resolved_archive_file = cached_file(
- pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
- )
- if resolved_archive_file is not None:
- is_sharded = True
- # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
- if resolved_archive_file is None and from_pt:
- resolved_archive_file = cached_file(
- pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
- )
- if resolved_archive_file is not None:
- is_sharded = True
- # If we still haven't found anything, look for `safetensors`.
- if resolved_archive_file is None:
- # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
- filename = SAFE_WEIGHTS_NAME
- resolved_archive_file = cached_file(
- pretrained_model_name_or_path, SAFE_WEIGHTS_NAME, **cached_file_kwargs
- )
- # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
- # result when internet is up, the repo and revision exist, but the file does not.
- if resolved_archive_file is None:
- # Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error
- # message.
- has_file_kwargs = {
- "revision": revision,
- "proxies": proxies,
- "token": token,
- "cache_dir": cache_dir,
- "local_files_only": local_files_only,
- }
- if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
- is_sharded = True
- raise NotImplementedError(
- "Support for sharded checkpoints using safetensors is coming soon!"
- )
- elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
- raise OSError(
- f"{pretrained_model_name_or_path} does not appear to have a file named"
- f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
- " load this model from those weights."
- )
- elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
- raise OSError(
- f"{pretrained_model_name_or_path} does not appear to have a file named"
- f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
- " `from_pt=True` to load this model from those weights."
- )
- else:
- raise OSError(
- f"{pretrained_model_name_or_path} does not appear to have a file named"
- f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
- )
- except OSError:
- # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
- # to the original exception.
- raise
- except Exception:
- # For any other exception, we throw a generic error.
- raise OSError(
- f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
- " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
- f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
- f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
- )
- if is_local:
- logger.info(f"loading weights file {archive_file}")
- resolved_archive_file = archive_file
- filename = resolved_archive_file.split(os.path.sep)[-1]
- else:
- logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
- else:
- resolved_archive_file = None
- # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
- if is_sharded:
- # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
- resolved_archive_file, _ = get_checkpoint_shard_files(
- pretrained_model_name_or_path,
- resolved_archive_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- token=token,
- user_agent=user_agent,
- revision=revision,
- subfolder=subfolder,
- _commit_hash=commit_hash,
- )
- safetensors_from_pt = False
- if filename == SAFE_WEIGHTS_NAME:
- with safe_open(resolved_archive_file, framework="flax") as f:
- safetensors_metadata = f.metadata()
- if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
- raise OSError(
- f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
- " Make sure you save your model with the `save_pretrained` method."
- )
- safetensors_from_pt = safetensors_metadata.get("format") == "pt"
- # init random models
- model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
- if from_pt or safetensors_from_pt:
- state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
- else:
- if is_sharded:
- state = cls.load_flax_sharded_weights(resolved_archive_file)
- else:
- state = cls.load_flax_weights(resolved_archive_file)
- # make sure all arrays are stored as jnp.arrays
- # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
- # https://github.com/google/flax/issues/1261
- if _do_init:
- state = jax.tree_util.tree_map(jnp.array, state)
- else:
- # keep the params on CPU if we don't want to initialize
- state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)
- if "batch_stats" in state: # if flax model contains batch norm layers
- # if model is base model only use model_prefix key
- if (
- cls.base_model_prefix not in dict(model.params_shape_tree["params"])
- and cls.base_model_prefix in state["params"]
- ):
- state["params"] = state["params"][cls.base_model_prefix]
- state["batch_stats"] = state["batch_stats"][cls.base_model_prefix]
- # if model is head model and we are loading weights from base model
- # we initialize new params dict with base_model_prefix
- if (
- cls.base_model_prefix in dict(model.params_shape_tree["params"])
- and cls.base_model_prefix not in state["params"]
- ):
- state = {
- "params": {cls.base_model_prefix: state["params"]},
- "batch_stats": {cls.base_model_prefix: state["batch_stats"]},
- }
- else:
- # if model is base model only use model_prefix key
- if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
- state = state[cls.base_model_prefix]
- # if model is head model and we are loading weights from base model
- # we initialize new params dict with base_model_prefix
- if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
- state = {cls.base_model_prefix: state}
- # flatten dicts
- state = flatten_dict(state)
- random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree))
- missing_keys = model.required_params - set(state.keys())
- unexpected_keys = set(state.keys()) - model.required_params
- # Disabling warning when porting pytorch weights to flax, flax does not uses num_batches_tracked
- for unexpected_key in unexpected_keys.copy():
- if "num_batches_tracked" in unexpected_key[-1]:
- unexpected_keys.remove(unexpected_key)
- if missing_keys and not _do_init:
- logger.warning(
- f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
- "Make sure to call model.init_weights to initialize the missing weights."
- )
- cls._missing_keys = missing_keys
- # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
- # matching the weights in the model.
- mismatched_keys = []
- for key in state:
- if key in random_state and state[key].shape != random_state[key].shape:
- if ignore_mismatched_sizes:
- mismatched_keys.append((key, state[key].shape, random_state[key].shape))
- state[key] = random_state[key]
- else:
- raise ValueError(
- f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
- f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
- "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
- "model."
- )
- # add missing keys as random parameters if we are initializing
- if missing_keys and _do_init:
- for missing_key in missing_keys:
- state[missing_key] = random_state[missing_key]
- # remove unexpected keys to not be saved again
- for unexpected_key in unexpected_keys:
- del state[unexpected_key]
- if len(unexpected_keys) > 0:
- logger.warning(
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
- f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
- f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
- " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
- " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
- f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
- " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
- )
- else:
- logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
- if len(missing_keys) > 0:
- logger.warning(
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
- f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
- " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
- )
- elif len(mismatched_keys) == 0:
- logger.info(
- f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
- f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
- f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
- " training."
- )
- if len(mismatched_keys) > 0:
- mismatched_warning = "\n".join(
- [
- f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
- for key, shape1, shape2 in mismatched_keys
- ]
- )
- logger.warning(
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
- f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
- f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
- " to use it for predictions and inference."
- )
- # dictionary of key: dtypes for the model params
- param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state)
- # extract keys of parameters not in jnp.float32
- fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
- bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
- # raise a warning if any of the parameters are not in jnp.float32
- if len(fp16_params) > 0:
- logger.warning(
- f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
- f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
- "You should probably UPCAST the model weights to float32 if this was not intended. "
- "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
- )
- if len(bf16_params) > 0:
- logger.warning(
- f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
- f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
- "You should probably UPCAST the model weights to float32 if this was not intended. "
- "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
- )
- # If it is a model with generation capabilities, attempt to load the generation config
- if model.can_generate():
- try:
- model.generation_config = GenerationConfig.from_pretrained(
- pretrained_model_name_or_path,
- cache_dir=cache_dir,
- force_download=force_download,
- resume_download=resume_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder=subfolder,
- _from_auto=from_auto_class,
- _from_pipeline=from_pipeline,
- **kwargs,
- )
- except OSError:
- logger.info(
- "Generation config file not found, using a generation config created from the model config."
- )
- pass
- if _do_init:
- # set correct parameters
- model.params = unflatten_dict(state)
- return model
- else:
- return model, unflatten_dict(state)
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    params=None,
    push_to_hub=False,
    max_shard_size="10GB",
    token: Optional[Union[str, bool]] = None,
    safe_serialization: bool = False,
    **kwargs,
):
    """
    Save a model and its configuration file to a directory, so that it can be re-loaded using the
    `[`~FlaxPreTrainedModel.from_pretrained`]` class method

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to which to save. Will be created if it doesn't exist.
        params (*optional*):
            The parameter tree to save. When `None`, `self.params` is saved.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
            repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
            namespace).
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
            lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).

            <Tip warning={true}>

            If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
            which will be bigger than `max_shard_size`.

            </Tip>

        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
            the token generated when running `hf auth login` (stored in `~/.huggingface`).
        safe_serialization (`bool`, *optional*, defaults to `False`):
            Whether to save the model using `safetensors` or through msgpack. Only applies to unsharded
            checkpoints: sharded checkpoints are written as msgpack shards plus an index file.
        kwargs (`dict[str, Any]`, *optional*):
            Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.

    Returns:
        `None`. If `save_directory` is an existing file, an error is logged and nothing is written.

    Raises:
        ValueError: If both `token` and the deprecated `use_auth_token` are given.
    """
    import re

    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError(
                "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
            )
        token = use_auth_token

    if token is not None:
        kwargs["token"] = token

    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    os.makedirs(save_directory, exist_ok=True)

    if push_to_hub:
        commit_message = kwargs.pop("commit_message", None)
        # `save_directory` may be an `os.PathLike` (no `.split`), and a trailing separator would otherwise
        # produce an empty name. Use the last normalized path component as the default repo name.
        default_repo_id = os.path.basename(os.path.normpath(os.fspath(save_directory)))
        repo_id = kwargs.pop("repo_id", default_repo_id)
        repo_id = self._create_repo(repo_id, **kwargs)
        files_timestamps = self._get_files_timestamps(save_directory)

    # get abs dir
    save_directory = os.path.abspath(save_directory)

    # Save the config as well. The first 4 characters of the class name are dropped
    # (presumably the `Flax` prefix).
    self.config.architectures = [self.__class__.__name__[4:]]

    # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
    # loaded from the Hub.
    if self._auto_class is not None:
        custom_object_save(self, save_directory, config=self.config)

    self.config.save_pretrained(save_directory)

    if self.can_generate():
        self.generation_config.save_pretrained(save_directory)

    # save model
    params = params if params is not None else self.params
    weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME
    output_model_file = os.path.join(save_directory, weights_name)

    shards, index = flax_shard_checkpoint(params, max_shard_size)

    # Clean the folder from a previous save, e.g. shards of an older checkpoint that this save will not
    # overwrite. Only remove names of the form `<stem>[-N-of-M]<ext>[.index.json]`, derived from
    # `weights_name`. Unrelated files that merely share the prefix are kept, for example a `modeling_*.py`
    # copied above when the weights stem is `model`.
    weights_stem, weights_ext = os.path.splitext(weights_name)
    stale_weights_pattern = re.compile(
        rf"{re.escape(weights_stem)}(?:-\d+-of-\d+)?{re.escape(weights_ext)}(?:\.index\.json)?"
    )
    for filename in os.listdir(save_directory):
        full_filename = os.path.join(save_directory, filename)
        if (
            stale_weights_pattern.fullmatch(filename) is not None
            and os.path.isfile(full_filename)
            and filename not in shards
        ):
            os.remove(full_filename)

    if index is None:
        saved_location = output_model_file
        if safe_serialization:
            flat_dict = flatten_dict(params, sep=".")
            safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"})
        else:
            with open(output_model_file, "wb") as f:
                f.write(to_bytes(params))
    else:
        # NOTE(review): this branch always serializes shards with `to_bytes` and writes the Flax index,
        # even when `safe_serialization=True`.
        save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
        saved_location = save_index_file
        # Save the index as well
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
        logger.info(
            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
            f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
            f"index located at {save_index_file}."
        )
        for shard_file, shard in shards.items():
            # Each shard is a flat dict with "/"-joined keys; rebuild the nested tree before serializing.
            with open(os.path.join(save_directory, shard_file), mode="wb") as f:
                shard_params = unflatten_dict(shard, sep="/")
                f.write(to_bytes(shard_params))

    logger.info(f"Model weights saved in {saved_location}")

    if push_to_hub:
        self._upload_modified_files(
            save_directory,
            repo_id,
            files_timestamps,
            commit_message=commit_message,
            token=token,
        )
@classmethod
def register_for_auto_class(cls, auto_class="FlaxAutoModel"):
    """
    Register this class with a given auto class. This should only be used for custom models as the ones in the
    library are already mapped with an auto class.

    Args:
        auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`):
            The auto class to register this new model with. A class object is reduced to its `__name__`.

    Raises:
        ValueError: If `transformers.models.auto` has no attribute with the resolved auto-class name.
    """
    auto_class_name = auto_class if isinstance(auto_class, str) else auto_class.__name__

    import transformers.models.auto as auto_module

    if not hasattr(auto_module, auto_class_name):
        raise ValueError(f"{auto_class_name} is not a valid auto class.")

    # Only the name is stored; it is resolved against the auto module elsewhere.
    cls._auto_class = auto_class_name
# Install a private copy of `push_to_hub` on the class before filling in its docstring template. This leaves
# the function object it was copied from (and its docstring) untouched.
FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub)
_push_to_hub_doc_template = FlaxPreTrainedModel.push_to_hub.__doc__
# The docstring can be missing (presumably when docstrings are stripped, e.g. `python -OO`).
if _push_to_hub_doc_template is not None:
    FlaxPreTrainedModel.push_to_hub.__doc__ = _push_to_hub_doc_template.format(
        object="model", object_class="FlaxAutoModel", object_files="model checkpoint"
    )
del _push_to_hub_doc_template
def overwrite_call_docstring(model_class, docstring):
    """
    Replace the docstring of `model_class.__call__` using `add_start_docstrings_to_model_forward(docstring)`.

    `__call__` is copied first, so only the copy installed on `model_class` is modified, not the function
    object it was copied from.
    """
    call_copy = copy_func(model_class.__call__)
    model_class.__call__ = call_copy
    # Drop the existing docstring so it is not carried into the new one.
    call_copy.__doc__ = None
    docstring_decorator = add_start_docstrings_to_model_forward(docstring)
    model_class.__call__ = docstring_decorator(call_copy)
def append_call_sample_docstring(
    model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None
):
    """
    Decorate a copy of `model_class.__call__` with `add_code_sample_docstrings` and install it on `model_class`.

    `__call__` is copied first, so the docstring change does not affect the function object it was copied from.

    Args:
        model_class: Class whose `__call__` docstring is extended; its `__name__` is passed as `model_cls`.
        checkpoint: Forwarded to `add_code_sample_docstrings`.
        output_type: Forwarded to `add_code_sample_docstrings`.
        config_class: Forwarded to `add_code_sample_docstrings`.
        mask (*optional*): Forwarded as `mask` only when not `None`; otherwise `add_code_sample_docstrings`
            uses whatever default it defines.
        revision (*optional*): Forwarded to `add_code_sample_docstrings`.
        real_checkpoint (*optional*): Forwarded to `add_code_sample_docstrings`.
    """
    model_class.__call__ = copy_func(model_class.__call__)
    # `mask` used to be accepted and silently ignored. Forward it only when supplied, so callers that omit
    # it get exactly the same docstring as before.
    optional_kwargs = {} if mask is None else {"mask": mask}
    model_class.__call__ = add_code_sample_docstrings(
        checkpoint=checkpoint,
        output_type=output_type,
        config_class=config_class,
        model_cls=model_class.__name__,
        revision=revision,
        real_checkpoint=real_checkpoint,
        **optional_kwargs,
    )(model_class.__call__)
def append_replace_return_docstrings(model_class, output_type, config_class):
    """
    Decorate a copy of `model_class.__call__` with `replace_return_docstrings` and install it on `model_class`.

    `__call__` is copied first, so the docstring change does not affect the function object it was copied from.
    """
    call_copy = copy_func(model_class.__call__)
    model_class.__call__ = call_copy
    returns_decorator = replace_return_docstrings(output_type=output_type, config_class=config_class)
    model_class.__call__ = returns_decorator(call_copy)
|