modeling_tf_utils.py 162 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
7327832793280328132823283328432853286328732883289329032913292329332943295329632973298329933003301330233033304330533063307330833093310331133123313331433153316331733183319332033213322332333243325332633273328332933303331333233333334333533363337333833393340334133423343334433453346334733483349335033513352335333543355335633573358335933603361336233633364336533663367336833693370337133723373337433753376337733783379338033813382338333843385338633873388338933903391339233933394339533963397339833993400340134023403340434053406340734083409341034113412341334143415341634173418341934203421342234233424342534263427342834293430343134323433343434353436343734383439344034413442344334443445344634473448344934503451345234533454345534563457345834593460346134623463346434653466346734683469347034713472347334743475347634773478347934803481348234833484348534863487348834893490349134923493349434953496349734983499350035013502350335043505350635073508350935103511351235133514351535163517351835193520352135223523352435253526352735283529
  1. # coding=utf-8
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """TF general model utils."""
  17. from __future__ import annotations
  18. import functools
  19. import gc
  20. import inspect
  21. import json
  22. import os
  23. import pickle
  24. import re
  25. import warnings
  26. from collections.abc import Mapping
  27. from pathlib import Path
  28. from typing import TYPE_CHECKING, Any, Callable, Union
  29. import h5py
  30. import numpy as np
  31. import tensorflow as tf
  32. from packaging.version import parse
  33. from . import DataCollatorWithPadding, DefaultDataCollator
  34. from .activations_tf import get_tf_activation
  35. from .configuration_utils import PretrainedConfig
  36. from .dynamic_module_utils import custom_object_save
  37. from .generation import GenerationConfig, TFGenerationMixin
  38. from .tf_utils import (
  39. convert_batch_encoding,
  40. expand_1d,
  41. load_attributes_from_hdf5_group,
  42. save_attributes_to_hdf5_group,
  43. shape_list,
  44. )
  45. from .utils import (
  46. SAFE_WEIGHTS_INDEX_NAME,
  47. SAFE_WEIGHTS_NAME,
  48. TF2_WEIGHTS_INDEX_NAME,
  49. TF2_WEIGHTS_NAME,
  50. TF_WEIGHTS_NAME,
  51. WEIGHTS_INDEX_NAME,
  52. WEIGHTS_NAME,
  53. ModelOutput,
  54. PushToHubMixin,
  55. cached_file,
  56. download_url,
  57. find_labels,
  58. has_file,
  59. is_offline_mode,
  60. is_remote_url,
  61. is_safetensors_available,
  62. is_tf_symbolic_tensor,
  63. logging,
  64. requires_backends,
  65. working_or_temp_dir,
  66. )
  67. from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
  68. if is_safetensors_available():
  69. from safetensors import safe_open
  70. from safetensors.tensorflow import save_file as safe_save_file
  71. if TYPE_CHECKING:
  72. from . import PreTrainedTokenizerBase
  73. logger = logging.get_logger(__name__)
  74. if "TF_USE_LEGACY_KERAS" not in os.environ:
  75. os.environ["TF_USE_LEGACY_KERAS"] = "1" # Compatibility fix to make sure tf.keras stays at Keras 2
  76. elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
  77. logger.warning(
  78. "Transformers is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
  79. "This may result in unexpected behaviour or errors if Keras 3 objects are passed to Transformers models."
  80. )
# Import Keras: prefer the standalone `tf_keras` package and fall back to the top-level `keras` package when
# `tf_keras` cannot be imported. `K` is the matching Keras backend module.
try:
    import tf_keras as keras
    from tf_keras import backend as K
except (ModuleNotFoundError, ImportError):
    import keras
    from keras import backend as K

    # The fallback `keras` package may be Keras 3 (major version > 2), which is rejected with an install hint.
    if parse(keras.__version__).major > 2:
        raise ValueError(
            "Your currently installed version of Keras is Keras 3, but this is not yet supported in "
            "Transformers. Please install the backwards-compatible tf-keras package with "
            "`pip install tf-keras`."
        )

# TensorFlow's own logger object.
tf_logger = tf.get_logger()

# Type alias for model inputs: a TF tensor or NumPy array, on its own, in a list, or in a str-keyed dict.
TFModelInputType = Union[
    list[tf.Tensor],
    list[np.ndarray],
    dict[str, tf.Tensor],
    dict[str, np.ndarray],
    tf.Tensor,
    np.ndarray,
]
  102. def dummy_loss(y_true, y_pred):
  103. if y_pred.shape.rank <= 1:
  104. return y_pred
  105. else:
  106. reduction_axes = list(range(1, y_pred.shape.rank))
  107. return tf.reduce_mean(y_pred, axis=reduction_axes)
  108. class TFModelUtilsMixin:
  109. """
  110. A few utilities for `keras.Model`, to be used as a mixin.
  111. """
  112. def num_parameters(self, only_trainable: bool = False) -> int:
  113. """
  114. Get the number of (optionally, trainable) parameters in the model.
  115. Args:
  116. only_trainable (`bool`, *optional*, defaults to `False`):
  117. Whether or not to return only the number of trainable parameters
  118. Returns:
  119. `int`: The number of parameters.
  120. """
  121. if only_trainable:
  122. return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
  123. else:
  124. return self.count_params()
  125. def keras_serializable(cls):
  126. """
  127. Decorate a Keras Layer class to support Keras serialization.
  128. This is done by:
  129. 1. Adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at
  130. serialization time.
  131. 2. Wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and
  132. convert it to a config object for the actual layer initializer.
  133. 3. Registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does not
  134. need to be supplied in `custom_objects` in the call to `keras.models.load_model`.
  135. Args:
  136. cls (a `keras.layers.Layers subclass`):
  137. Typically a `TF.MainLayer` class in this project, in general must accept a `config` argument to its
  138. initializer.
  139. Returns:
  140. The same class object, with modifications for Keras deserialization.
  141. """
  142. initializer = cls.__init__
  143. config_class = getattr(cls, "config_class", None)
  144. if config_class is None:
  145. raise AttributeError("Must set `config_class` to use @keras_serializable")
  146. @functools.wraps(initializer)
  147. def wrapped_init(self, *args, **kwargs):
  148. config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.pop("config", None)
  149. if isinstance(config, dict):
  150. config = config_class.from_dict(config)
  151. initializer(self, config, *args, **kwargs)
  152. elif isinstance(config, PretrainedConfig):
  153. if len(args) > 0:
  154. initializer(self, *args, **kwargs)
  155. else:
  156. initializer(self, config, *args, **kwargs)
  157. else:
  158. raise TypeError("Must pass either `config` (PretrainedConfig) or `config` (dict)")
  159. self._config = config
  160. self._kwargs = kwargs
  161. cls.__init__ = wrapped_init
  162. if not hasattr(cls, "get_config"):
  163. raise TypeError("Only use @keras_serializable on keras.layers.Layer subclasses")
  164. if hasattr(cls.get_config, "_is_default"):
  165. def get_config(self):
  166. cfg = super(cls, self).get_config()
  167. cfg["config"] = self._config.to_dict()
  168. cfg.update(self._kwargs)
  169. return cfg
  170. cls.get_config = get_config
  171. cls._keras_serializable = True
  172. if hasattr(keras.utils, "register_keras_serializable"):
  173. cls = keras.utils.register_keras_serializable()(cls)
  174. return cls
  175. class TFCausalLanguageModelingLoss:
  176. """
  177. Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token.
  178. <Tip>
  179. Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
  180. </Tip>
  181. """
  182. def hf_compute_loss(self, labels, logits):
  183. loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
  184. if self.config.tf_legacy_loss:
  185. # make sure only labels that are not equal to -100 affect the loss
  186. active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
  187. reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
  188. labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
  189. return loss_fn(labels, reduced_logits)
  190. # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
  191. unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
  192. # make sure only labels that are not equal to -100 affect the loss
  193. loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype)
  194. masked_loss = unmasked_loss * loss_mask
  195. reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
  196. return tf.reshape(reduced_masked_loss, (1,))
  197. class TFQuestionAnsweringLoss:
  198. """
  199. Loss function suitable for question answering.
  200. """
  201. def hf_compute_loss(self, labels, logits):
  202. loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
  203. start_loss = loss_fn(labels["start_position"], logits[0])
  204. end_loss = loss_fn(labels["end_position"], logits[1])
  205. return (start_loss + end_loss) / 2.0
  206. class TFTokenClassificationLoss:
  207. """
  208. Loss function suitable for token classification.
  209. <Tip>
  210. Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
  211. </Tip>
  212. """
  213. def hf_compute_loss(self, labels, logits):
  214. loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
  215. if tf.executing_eagerly(): # Data-dependent conditionals are forbidden in XLA
  216. if tf.math.reduce_any(labels == -1):
  217. tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
  218. if self.config.tf_legacy_loss:
  219. # make sure only labels that are not equal to -100
  220. # are taken into account as loss
  221. if tf.math.reduce_any(labels == -1):
  222. tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
  223. active_loss = tf.reshape(labels, (-1,)) != -1
  224. else:
  225. active_loss = tf.reshape(labels, (-1,)) != -100
  226. reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
  227. labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
  228. return loss_fn(labels, reduced_logits)
  229. # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
  230. unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
  231. # make sure only labels that are not equal to -100 or -1
  232. # are taken into account as loss
  233. loss_mask = tf.cast(labels >= 0, dtype=unmasked_loss.dtype)
  234. # Avoid possible division by zero later
  235. # Masked positions will have a loss of NaN because -100 and -1 are not valid labels
  236. masked_loss = unmasked_loss * loss_mask
  237. reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
  238. return tf.reshape(reduced_masked_loss, (1,))
  239. class TFSequenceClassificationLoss:
  240. """
  241. Loss function suitable for sequence classification.
  242. """
  243. def hf_compute_loss(self, labels, logits):
  244. if logits.shape.rank == 1 or logits.shape[1] == 1:
  245. loss_fn = keras.losses.MeanSquaredError(reduction=keras.losses.Reduction.NONE)
  246. if labels.shape.rank == 1:
  247. # MeanSquaredError returns a scalar loss if the labels are 1D, so avoid that
  248. labels = tf.expand_dims(labels, axis=-1)
  249. else:
  250. loss_fn = keras.losses.SparseCategoricalCrossentropy(
  251. from_logits=True, reduction=keras.losses.Reduction.NONE
  252. )
  253. return loss_fn(labels, logits)
  254. class TFMultipleChoiceLoss:
  255. """Loss function suitable for multiple choice tasks."""
  256. def hf_compute_loss(self, labels, logits):
  257. loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
  258. return loss_fn(labels, logits)
  259. class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
  260. """
  261. Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens.
  262. <Tip>
  263. Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
  264. </Tip>
  265. """
  266. class TFNextSentencePredictionLoss:
  267. """
  268. Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence.
  269. <Tip>
  270. Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
  271. </Tip>
  272. """
  273. def hf_compute_loss(self, labels, logits):
  274. loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
  275. if self.config.tf_legacy_loss:
  276. # make sure only labels that are not equal to -100
  277. # are taken into account as loss
  278. next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
  279. next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
  280. next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
  281. return loss_fn(next_sentence_label, next_sentence_reduced_logits)
  282. # make sure only labels that are not equal to -100
  283. # are taken into account as loss
  284. # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
  285. unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits)
  286. ns_loss_mask = tf.cast(labels != -100, dtype=unmasked_ns_loss.dtype)
  287. # Just zero out samples where label is -100, no reduction
  288. masked_ns_loss = unmasked_ns_loss * ns_loss_mask
  289. return masked_ns_loss
  290. def booleans_processing(config, **kwargs):
  291. """
  292. Process the input booleans of each model.
  293. Args:
  294. config ([`PretrainedConfig`]):
  295. The config of the running model.
  296. **kwargs:
  297. The boolean parameters
  298. Returns:
  299. A dictionary with the proper values for each boolean
  300. """
  301. final_booleans = {}
  302. # Pure conv models (such as ConvNext) do not have `output_attentions`. If the signature has
  303. # `output_attentions`, it will be present here in `kwargs`, even if unset (in that case, as `None`)
  304. if "output_attentions" in kwargs:
  305. final_booleans["output_attentions"] = (
  306. kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions
  307. )
  308. final_booleans["output_hidden_states"] = (
  309. kwargs["output_hidden_states"] if kwargs["output_hidden_states"] is not None else config.output_hidden_states
  310. )
  311. final_booleans["return_dict"] = kwargs["return_dict"] if kwargs["return_dict"] is not None else config.return_dict
  312. if "use_cache" in kwargs:
  313. final_booleans["use_cache"] = (
  314. kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None)
  315. )
  316. return final_booleans
  317. def unpack_inputs(func):
  318. """
  319. Decorator that processes the inputs to a Keras layer, passing them to the layer as keyword arguments. This enables
  320. downstream use of the inputs by their variable name, even if they arrive packed as a dictionary in the first input
  321. (common case in Keras).
  322. Args:
  323. func (`callable`):
  324. The callable function of the TensorFlow model.
  325. Returns:
  326. A callable that wraps the original `func` with the behavior described above.
  327. """
  328. original_signature = inspect.signature(func)
  329. @functools.wraps(func)
  330. def run_call_with_unpacked_inputs(self, *args, **kwargs):
  331. # isolates the actual `**kwargs` for the decorated function
  332. kwargs_call = {key: val for key, val in kwargs.items() if key not in dict(original_signature.parameters)}
  333. fn_args_and_kwargs = {key: val for key, val in kwargs.items() if key not in kwargs_call}
  334. fn_args_and_kwargs.update({"kwargs_call": kwargs_call})
  335. # move any arg into kwargs, if they exist
  336. fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
  337. # Encoder Decoder models delegate the application of the configuration options to their inner models.
  338. if "EncoderDecoder" in self.__class__.__name__:
  339. config = None
  340. else:
  341. config = self.config
  342. unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
  343. return func(self, **unpacked_inputs)
  344. # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
  345. # function does not follow wrapper chains (i.e. ignores `functools.wraps()`), meaning that without the line below
  346. # Keras would attempt to check the first argument against the literal signature of the wrapper.
  347. run_call_with_unpacked_inputs.__signature__ = original_signature
  348. return run_call_with_unpacked_inputs
  349. def input_processing(func, config, **kwargs):
  350. """
  351. Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
  352. has to be named accordingly to the parameters name, i.e. `input_ids = keras.Input(shape=(128,), dtype='int32',
  353. name="input_ids")` otherwise the order of the tensors will not be guaranteed during the training.
  354. Args:
  355. func (`callable`):
  356. The callable function of the TensorFlow model.
  357. config ([`PretrainedConfig`]):
  358. The config of the running model.
  359. **kwargs:
  360. The inputs of the model.
  361. Returns:
  362. Two lists, one for the missing layers, and another one for the unexpected layers.
  363. """
  364. signature = dict(inspect.signature(func).parameters)
  365. has_kwargs = bool(signature.pop("kwargs", None))
  366. signature.pop("self", None)
  367. parameter_names = list(signature.keys())
  368. main_input_name = parameter_names[0]
  369. main_input = kwargs.pop(main_input_name, None)
  370. output = {}
  371. allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
  372. if "inputs" in kwargs["kwargs_call"]:
  373. warnings.warn(
  374. "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
  375. FutureWarning,
  376. )
  377. output["input_ids"] = kwargs["kwargs_call"].pop("inputs")
  378. if "decoder_cached_states" in kwargs["kwargs_call"]:
  379. warnings.warn(
  380. "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
  381. " `past_key_values` instead.",
  382. FutureWarning,
  383. )
  384. output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
  385. if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names:
  386. warnings.warn(
  387. "The `past` argument is deprecated and will be removed in a future version, use `past_key_values`"
  388. " instead.",
  389. FutureWarning,
  390. )
  391. kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past")
  392. elif "past_key_values" in kwargs["kwargs_call"] and "past" in parameter_names:
  393. kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")
  394. if has_kwargs:
  395. output["kwargs"] = kwargs.pop("kwargs_call", {})
  396. else:
  397. if len(kwargs["kwargs_call"]) > 0:
  398. raise ValueError(
  399. "The following keyword arguments are not supported by this model:"
  400. f" {list(kwargs['kwargs_call'].keys())}."
  401. )
  402. kwargs.pop("kwargs_call")
  403. for k, v in kwargs.items():
  404. if isinstance(v, allowed_types) or tf.is_tensor(v) or v is None:
  405. output[k] = v
  406. else:
  407. raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
  408. if isinstance(main_input, (tuple, list)):
  409. for i, input in enumerate(main_input):
  410. # EagerTensors don't allow to use the .name property so we check for a real Tensor
  411. if is_tf_symbolic_tensor(input):
  412. # Tensor names have always the pattern `name:id` then we check only the
  413. # `name` part
  414. tensor_name = input.name.split(":")[0]
  415. if tensor_name in parameter_names:
  416. output[tensor_name] = input
  417. else:
  418. output[parameter_names[i]] = input
  419. elif isinstance(input, allowed_types) or input is None:
  420. output[parameter_names[i]] = input
  421. else:
  422. raise ValueError(
  423. f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
  424. f" {parameter_names[i]}."
  425. )
  426. elif isinstance(main_input, Mapping):
  427. if "inputs" in main_input:
  428. warnings.warn(
  429. "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`"
  430. " instead.",
  431. FutureWarning,
  432. )
  433. output["input_ids"] = main_input.pop("inputs")
  434. if "decoder_cached_states" in main_input:
  435. warnings.warn(
  436. "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
  437. " `past_key_values` instead.",
  438. FutureWarning,
  439. )
  440. output["past_key_values"] = main_input.pop("decoder_cached_states")
  441. for k, v in dict(main_input).items():
  442. if isinstance(v, allowed_types) or v is None:
  443. output[k] = v
  444. elif k not in parameter_names and "args" not in parameter_names:
  445. logger.warning(
  446. f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
  447. )
  448. continue
  449. else:
  450. raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
  451. else:
  452. if tf.is_tensor(main_input) or main_input is None:
  453. output[main_input_name] = main_input
  454. else:
  455. raise ValueError(
  456. f"Data of type {type(main_input)} is not allowed only {allowed_types} is accepted for"
  457. f" {main_input_name}."
  458. )
  459. # Populates any unspecified argument with their default value, according to the signature.
  460. for name in parameter_names:
  461. if name not in list(output.keys()) and name != "args":
  462. output[name] = kwargs.pop(name, signature[name].default)
  463. # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
  464. # So to respect the proper output we have to add this exception
  465. if "args" in output:
  466. if output["args"] is not None and is_tf_symbolic_tensor(output["args"]):
  467. tensor_name = output["args"].name.split(":")[0]
  468. output[tensor_name] = output["args"]
  469. else:
  470. # `args` in this case is always the first parameter, then `input_ids`
  471. output["input_ids"] = output["args"]
  472. del output["args"]
  473. if "kwargs" in output:
  474. del output["kwargs"]
  475. cast_output = {}
  476. for key, val in output.items():
  477. if isinstance(val, tf.Tensor) and val.dtype == tf.int64:
  478. cast_output[key] = tf.cast(val, tf.int32)
  479. elif isinstance(val, np.ndarray) and val.dtype == np.int64:
  480. cast_output[key] = val.astype(np.int32)
  481. else:
  482. cast_output[key] = val
  483. output = cast_output
  484. del cast_output
  485. if config is not None:
  486. boolean_dict = {
  487. k: v
  488. for k, v in output.items()
  489. if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
  490. }
  491. output.update(
  492. booleans_processing(
  493. config=config,
  494. **boolean_dict,
  495. )
  496. )
  497. return output
  498. def strip_model_name_and_prefix(name, _prefix=None):
  499. if _prefix is not None and name.startswith(_prefix):
  500. name = name[len(_prefix) :]
  501. if name.startswith("/"):
  502. name = name[1:]
  503. if "model." not in name and len(name.split("/")) > 1:
  504. name = "/".join(name.split("/")[1:])
  505. return name
  506. def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME):
  507. """
  508. Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
  509. given size.
  510. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
  511. optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
  512. limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
  513. [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
  514. <Tip warning={true}>
  515. If one of the model's weight is bigger that `max_shard_size`, it will end up in its own sub-checkpoint which will
  516. have a size greater than `max_shard_size`.
  517. </Tip>
  518. Args:
  519. weights (`dict[str, tf.RessourceVariable]`): The list of tf.RessourceVariable of a model to save.
  520. max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
  521. The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
  522. (like `"5MB"`).
  523. """
  524. max_shard_size = convert_file_size_to_int(max_shard_size)
  525. sharded_state_dicts = []
  526. current_block = []
  527. current_block_size = 0
  528. total_size = 0
  529. for item in weights:
  530. weight_size = item.numpy().size * item.dtype.size
  531. # If this weight is going to tip up over the maximal size, we split.
  532. if current_block_size + weight_size > max_shard_size:
  533. sharded_state_dicts.append(current_block)
  534. current_block = []
  535. current_block_size = 0
  536. current_block.append(item)
  537. current_block_size += weight_size
  538. total_size += weight_size
  539. # Add the last block
  540. sharded_state_dicts.append(current_block)
  541. # If we only have one shard, we return it
  542. if len(sharded_state_dicts) == 1:
  543. return {weights_name: sharded_state_dicts[0]}, None
  544. # Otherwise, let's build the index
  545. weight_map = {}
  546. shards = {}
  547. for idx, shard in enumerate(sharded_state_dicts):
  548. shard_file = weights_name.replace(".h5", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.h5")
  549. shard_file = shard_file.replace(
  550. ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
  551. )
  552. shards[shard_file] = shard
  553. for weight in shard:
  554. weight_name = weight.name
  555. weight_map[weight_name] = shard_file
  556. # Add the metadata
  557. metadata = {"total_size": total_size}
  558. index = {"metadata": metadata, "weight_map": weight_map}
  559. return shards, index
  560. def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None):
  561. """
  562. This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load
  563. the TF weights from the shard file accordingly to their names and shapes.
  564. This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
  565. loaded in the model.
  566. Args:
  567. model (`keras.models.Model`): The model in which to load the checkpoint.
  568. shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
  569. ignore_mismatched_sizes`bool`, *optional`, defaults to `True`):
  570. Whether or not to ignore the mismatch between the sizes
  571. strict (`bool`, *optional*, defaults to `True`):
  572. Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
  573. Returns:
  574. Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
  575. mismatched layers.
  576. """
  577. # Load the index
  578. unexpected_keys = set()
  579. saved_keys = set()
  580. mismatched_keys = set()
  581. # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
  582. # the weight, we have to get rid of the first prefix of the name of the layer.
  583. model_keys = set()
  584. model_layer_map = {}
  585. for i, k in enumerate(model.weights):
  586. layer_name = k.name
  587. if _prefix is not None and layer_name.startswith(_prefix):
  588. layer_name = layer_name[len(_prefix) :]
  589. layer_name = layer_name.lstrip("/")
  590. if not ("model." in layer_name or len(layer_name.split("/")) == 1):
  591. layer_name = "/".join(layer_name.split("/")[1:])
  592. model_keys.add(layer_name)
  593. model_layer_map[layer_name] = i
  594. for shard_file in shard_files:
  595. saved_weight_names_set, unexpected_keys_set, mismatched_keys_set = load_tf_shard(
  596. model,
  597. model_layer_map,
  598. shard_file,
  599. ignore_mismatched_sizes=ignore_mismatched_sizes,
  600. _prefix=_prefix,
  601. )
  602. saved_keys.update(saved_weight_names_set)
  603. unexpected_keys.update(unexpected_keys_set)
  604. mismatched_keys.update(mismatched_keys_set)
  605. gc.collect()
  606. missing_keys = model_keys - saved_keys
  607. if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
  608. error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
  609. if len(missing_keys) > 0:
  610. str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
  611. error_message += f"\nMissing key(s): {str_missing_keys}."
  612. if len(unexpected_keys) > 0:
  613. str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
  614. error_message += f"\nMissing key(s): {str_unexpected_keys}."
  615. raise RuntimeError(error_message)
  616. return missing_keys, unexpected_keys, mismatched_keys
  617. def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
  618. """
  619. Loads a shard from a sharded checkpoint file. Can be either H5 or Safetensors.
  620. Handles missing keys and unexpected keys.
  621. Args:
  622. model (`keras.models.Model`): Model in which the weights are loaded
  623. model_layer_map (`Dict`): A dictionary mapping the layer name to the index of the layer in the model.
  624. resolved_archive_file (`str`): Path to the checkpoint file from which the weights will be loaded
  625. ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether to ignore the mismatched keys
  626. Returns:
  627. `keras.models.Model`: Three lists, one for the layers that were found and successfully restored (from the
  628. shard file), one for the mismatched layers, and another one for the unexpected layers.
  629. """
  630. saved_weight_names_set = set()
  631. saved_weights = {}
  632. mismatched_keys = set()
  633. unexpected_keys = set()
  634. # Read the H5 file
  635. try:
  636. with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
  637. # Retrieve the name of each layer from the H5 file
  638. saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names"))
  639. weight_value_tuples = []
  640. # Compute missing and unexpected sub layers
  641. # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
  642. for layer_name in saved_h5_model_layers_name:
  643. h5_layer_object = sharded_checkpoint_file[layer_name]
  644. saved_weights[layer_name] = np.asarray(h5_layer_object)
  645. saved_weight_names_set.add(layer_name)
  646. if layer_name not in model_layer_map:
  647. unexpected_keys.add(layer_name)
  648. else:
  649. symbolic_weight = model.weights[model_layer_map[layer_name]]
  650. saved_weight_value = saved_weights[layer_name]
  651. # If the current weight is found
  652. if saved_weight_value is not None:
  653. # Check if the shape of the current weight and the one from the H5 file are different
  654. if K.int_shape(symbolic_weight) != saved_weight_value.shape:
  655. # If yes we reshape the weight from the H5 file accordingly to the current weight
  656. # If the two shapes are not compatible we raise an issue
  657. try:
  658. array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
  659. except ValueError as e:
  660. if ignore_mismatched_sizes:
  661. mismatched_keys.add(
  662. (layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
  663. )
  664. continue
  665. else:
  666. raise e
  667. else:
  668. array = saved_weight_value
  669. # We create the tuple that will be loaded and add it to the final list
  670. weight_value_tuples.append((symbolic_weight, array))
  671. K.batch_set_value(weight_value_tuples)
  672. return saved_weight_names_set, unexpected_keys, mismatched_keys
  673. except Exception as e:
  674. try:
  675. with open(resolved_archive_file) as f:
  676. if f.read().startswith("version"):
  677. raise OSError(
  678. "You seem to have cloned a repository without having git-lfs installed. Please install "
  679. "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
  680. "you cloned."
  681. )
  682. else:
  683. raise ValueError(
  684. f"Unable to locate the file {resolved_archive_file} which is necessary to load this pretrained"
  685. " model. Make sure you have saved the model properly."
  686. ) from e
  687. except (UnicodeDecodeError, ValueError):
  688. raise OSError(
  689. f"Unable to load weights from TF checkpoint file for '{resolved_archive_file}' "
  690. f"at '{resolved_archive_file}'. "
  691. "If you tried to load a TF model from a sharded checkpoint, you should try converting the model "
  692. "by loading it in pytorch and saving it locally. A conversion script should be released soon."
  693. )
  694. def load_tf_sharded_weights_from_safetensors(
  695. model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None
  696. ):
  697. """
  698. This is the same as `load_tf_weights_from_safetensors` but for a sharded TF-format safetensors checkpoint.
  699. Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and
  700. shapes.
  701. This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
  702. loaded in the model.
  703. Args:
  704. model (`keras.models.Model`): The model in which to load the checkpoint.
  705. shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
  706. ignore_mismatched_sizes`bool`, *optional`, defaults to `True`):
  707. Whether or not to ignore the mismatch between the sizes
  708. strict (`bool`, *optional*, defaults to `True`):
  709. Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
  710. Returns:
  711. Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
  712. mismatched layers.
  713. """
  714. # Load the index
  715. unexpected_keys = set()
  716. all_missing_keys = []
  717. mismatched_keys = set()
  718. for shard_file in shard_files:
  719. missing_layers, unexpected_layers, mismatched_layers = load_tf_weights_from_safetensors(
  720. model,
  721. shard_file,
  722. ignore_mismatched_sizes=ignore_mismatched_sizes,
  723. _prefix=_prefix,
  724. )
  725. all_missing_keys.append(set(missing_layers))
  726. unexpected_keys.update(unexpected_layers)
  727. mismatched_keys.update(mismatched_layers)
  728. gc.collect()
  729. missing_keys = set.intersection(*all_missing_keys)
  730. if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
  731. error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
  732. if len(missing_keys) > 0:
  733. str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
  734. error_message += f"\nMissing key(s): {str_missing_keys}."
  735. if len(unexpected_keys) > 0:
  736. str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
  737. error_message += f"\nMissing key(s): {str_unexpected_keys}."
  738. raise RuntimeError(error_message)
  739. return missing_keys, unexpected_keys, mismatched_keys
  740. def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
  741. """
  742. Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and
  743. shapes.
  744. Args:
  745. model (`keras.models.Model`):
  746. The model to load the weights into.
  747. resolved_archive_file (`str`):
  748. The location of the H5 file.
  749. ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
  750. Whether or not to ignore weights with shapes that don't match between the checkpoint of the model.
  751. Returns:
  752. Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
  753. mismatched layers.
  754. """
  755. if resolved_archive_file.endswith(".safetensors"):
  756. load_function = load_tf_weights_from_safetensors
  757. else:
  758. load_function = load_tf_weights_from_h5
  759. return load_function(
  760. model, resolved_archive_file, ignore_mismatched_sizes=ignore_mismatched_sizes, _prefix=_prefix
  761. )
def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
    """
    Load the weights stored in a (non-sharded) H5 checkpoint into `model`, matching them by layer and weight name.

    Matching happens in two passes:

    - top-level layers are compared by `layer.name` against the `layer_names` attribute of the H5 file;
    - inside every layer present on both sides, each model weight (trainable, then non-trainable) is matched by name
      against the weights listed in that layer's `weight_names` attribute.

    Args:
        model (`keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (`str`):
            Path to the H5 file.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            If `True`, weights whose saved value cannot be reshaped to the model weight's shape are skipped and
            reported in `mismatched_layers`; if `False`, the `ValueError` raised by `np.reshape` is re-raised.
        _prefix (`str`, *optional*):
            When given, it is prepended to the saved weight names. The path component at index
            `len(_prefix.split("/"))` is removed from the model weight names. This is presumably the component right
            after the prefix, assuming the model weight name starts with `_prefix`; confirm against callers.

    Returns:
        `tuple(list, list, list)`: `(missing_layers, unexpected_layers, mismatched_layers)`, where:

        - `missing_layers`: names of top-level model layers absent from the file, plus names of weights of matched
          layers that were not found in the file.
        - `unexpected_layers`: names of top-level layers in the file absent from the model, plus names of saved
          weights of matched layers that no model weight corresponds to.
        - `mismatched_layers`: `(weight_name, saved_shape, model_shape)` tuples, only filled when
          `ignore_mismatched_sizes=True`.
    """
    mismatched_layers = []

    # Read the H5 file
    with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
        # Retrieve the name of each layer from the H5 file
        saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names"))

        # Find the missing layers from the high level list of layers
        missing_layers = list({layer.name for layer in model.layers} - saved_h5_model_layers_name)

        # Find the unexpected layers from the high level list of layers
        unexpected_layers = list(saved_h5_model_layers_name - {layer.name for layer in model.layers})
        saved_weight_names_set = set()
        symbolic_weights_names = set()
        weight_value_tuples = []

        # Compute missing and unexpected sub layers
        # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
        for layer in model.layers:
            # if layer_name from the H5 file belongs to the layers from the instantiated model
            if layer.name in saved_h5_model_layers_name:
                # Get the H5 layer object from its name
                h5_layer_object = sharded_checkpoint_file[layer.name]
                # Get all the weights as a list from the layer object
                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
                saved_weights = {}

                # Create a dict from the H5 saved model that looks like {"weight_name": weight_value}
                # And a set with only the names
                for weight_name in load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
                    # TF names always start with the model name so we ignore it
                    name = "/".join(weight_name.split("/")[1:])

                    if _prefix is not None:
                        name = _prefix + "/" + name

                    saved_weights[name] = np.asarray(h5_layer_object[weight_name])

                    # Add the updated name to the final list for computing missing/unexpected values
                    saved_weight_names_set.add(name)

                # Loop over each weights from the instantiated model and compare with the weights from the H5 file
                for symbolic_weight in symbolic_weights:
                    # TF names always start with the model name so we ignore it
                    if _prefix is not None:
                        # Keep the first `delimiter` components, drop the one at index `delimiter`, keep the rest.
                        delimiter = len(_prefix.split("/"))
                        symbolic_weight_name = "/".join(
                            symbolic_weight.name.split("/")[:delimiter]
                            + symbolic_weight.name.split("/")[delimiter + 1 :]
                        )
                    else:
                        symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])

                    # here we check if the current weight is among the weights from the H5 file
                    # If yes, get the weight_value of the corresponding weight from the H5 file
                    # If not, make the value to None
                    saved_weight_value = saved_weights.get(symbolic_weight_name)

                    # Retrocompatibility patch: some embeddings are stored with the weight name (e.g. Bart's
                    # `model.shared/embeddings:0` are stored as `model.shared/weight:0`).
                    # `[:-12]` strips the 12-character suffix "embeddings:0".
                    if saved_weight_value is None and symbolic_weight_name.endswith("embeddings:0"):
                        symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0"
                        saved_weight_value = saved_weights.get(symbolic_weight_name)

                    # Add the updated name to the final list for computing missing/unexpected values
                    symbolic_weights_names.add(symbolic_weight_name)

                    # If the current weight is found
                    if saved_weight_value is not None:
                        # Check if the shape of the current weight and the one from the H5 file are different
                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
                            # If yes we reshape the weight from the H5 file accordingly to the current weight
                            # If the two shapes are not compatible we raise an issue
                            try:
                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
                            except ValueError as e:
                                if ignore_mismatched_sizes:
                                    mismatched_layers.append(
                                        (symbolic_weight_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
                                    )
                                    continue
                                else:
                                    raise e
                        else:
                            array = saved_weight_value

                        # We create the tuple that will be loaded and add it to the final list
                        weight_value_tuples.append((symbolic_weight, array))

    # Load all the weights
    K.batch_set_value(weight_value_tuples)

    # Compute the missing and unexpected layers
    missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
    unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))

    return missing_layers, unexpected_layers, mismatched_layers
  843. def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
  844. # Read the safetensors file
  845. with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
  846. mismatched_layers = []
  847. weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights]
  848. loaded_weight_names = list(safetensors_archive.keys())
  849. # Find the missing layers from the high level list of layers
  850. missing_layers = list(set(weight_names) - set(loaded_weight_names))
  851. # Find the unexpected layers from the high level list of layers
  852. unexpected_layers = list(set(loaded_weight_names) - set(weight_names))
  853. for weight in model.weights:
  854. weight_name = strip_model_name_and_prefix(weight.name, _prefix=_prefix)
  855. if weight_name in loaded_weight_names:
  856. weight_value = safetensors_archive.get_tensor(weight_name)
  857. # Check if the shape of the current weight and the one from the H5 file are different
  858. if K.int_shape(weight) != weight_value.shape:
  859. # If yes we reshape the weight from the H5 file accordingly to the current weight
  860. # If the two shapes are not compatible we raise an issue
  861. try:
  862. weight_value = tf.reshape(weight_value, K.int_shape(weight))
  863. except (ValueError, tf.errors.InvalidArgumentError) as e:
  864. if ignore_mismatched_sizes:
  865. mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
  866. continue
  867. else:
  868. raise e
  869. K.set_value(weight, weight_value) # weight.assign() might break if weight is a DTensor
  870. return missing_layers, unexpected_layers, mismatched_layers
  871. def init_copy_embeddings(old_embeddings, new_num_tokens):
  872. r"""
  873. This function aims to reduce the embeddings in case new_num_tokens < old_num_tokens or to pad with -1 in case
  874. new_num_tokens > old_num_tokens. A mask is also computed in order to know which weight in the embeddings should be
  875. kept or not. Example:
  876. - if new_num_tokens=5 and old_num_tokens=4 and old_embeddings=[w1,w2,w3,w4]
  877. - mask=[True,True,True,True,False] and current_weights=[w1,w2,w3,w4,-1]
  878. - if new_num_tokens=4 and old_num_tokens=5 and old_embeddings=[w1,w2,w3,w4,w5]
  879. - mask=[True,True,True,True] and current_weights=[w1,w2,w3,w4]
  880. """
  881. old_num_tokens, old_embedding_dim = shape_list(old_embeddings)
  882. size_diff = new_num_tokens - old_num_tokens
  883. # initialize new embeddings
  884. # Copy token embeddings from the previous ones
  885. if tf.math.greater(size_diff, 0):
  886. # if the new size is greater than the old one, we extend the current embeddings with a padding until getting new size
  887. # and we create a mask to properly identify the padded values and be replaced by the values of the newly created
  888. # embeddings
  889. current_weights = tf.pad(
  890. old_embeddings.value(), tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=-1
  891. )
  892. num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
  893. mask = tf.fill(tf.convert_to_tensor([num_tokens_to_copy, 1]), True)
  894. mask = tf.pad(mask, tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=False)
  895. else:
  896. # if the new size if lower than the old one, we take the current embeddings until the new size
  897. current_weights = tf.slice(
  898. old_embeddings.value(),
  899. tf.convert_to_tensor([0, 0]),
  900. tf.convert_to_tensor([new_num_tokens, old_embedding_dim]),
  901. )
  902. mask = tf.fill(tf.convert_to_tensor([new_num_tokens, 1]), True)
  903. return mask, current_weights
  904. class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
  905. r"""
  906. Base class for all TF models.
  907. [`TFPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
  908. downloading and saving models as well as a few methods common to all models to:
  909. - resize the input embeddings,
  910. - prune heads in the self-attention heads.
  911. Class attributes (overridden by derived classes):
  912. - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
  913. for this model architecture.
  914. - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
  915. classes of the same architecture adding modules on top of the base model.
  916. - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
  917. models, `pixel_values` for vision models and `input_values` for speech models).
  918. """
  919. config_class = None
  920. base_model_prefix = ""
  921. main_input_name = "input_ids"
  922. _auto_class = None
  923. _using_dummy_loss = None
  924. _label_to_output_map = None
  925. # a list of re pattern of tensor names to ignore from the model when loading the model weights
  926. # (and avoid unnecessary warnings).
  927. _keys_to_ignore_on_load_missing = None
  928. # a list of re pattern of tensor names to ignore from the weights when loading the model weights
  929. # (and avoid unnecessary warnings).
  930. _keys_to_ignore_on_load_unexpected = None
  931. _requires_load_weight_prefix = False
  932. @property
  933. def dummy_inputs(self) -> dict[str, tf.Tensor]:
  934. """
  935. Dummy inputs to build the network.
  936. Returns:
  937. `dict[str, tf.Tensor]`: The dummy inputs.
  938. """
  939. dummies = {}
  940. for key, spec in self.input_signature.items():
  941. # 2 is the most correct arbitrary size. I will not be taking questions
  942. dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
  943. if spec.shape[0] is None:
  944. # But let's make the batch size 1 to save memory anyway
  945. dummy_shape[0] = 1
  946. dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype)
  947. if key == "token_type_ids":
  948. # Some models have token_type_ids but with a vocab_size of 1
  949. dummies[key] = tf.zeros_like(dummies[key])
  950. if self.config.add_cross_attention and "encoder_hidden_states" in inspect.signature(self.call).parameters:
  951. if "encoder_hidden_states" not in dummies:
  952. if self.main_input_name == "input_ids":
  953. dummies["encoder_hidden_states"] = tf.ones(
  954. shape=(1, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
  955. )
  956. else:
  957. raise NotImplementedError(
  958. "Model has cross-attention but we couldn't infer the shape for the encoder hidden states. Please manually override dummy_inputs!"
  959. )
  960. return dummies
  961. def build_in_name_scope(self):
  962. with tf.name_scope(self.name):
  963. self.build(input_shape=None)
  964. @property
  965. def framework(self) -> str:
  966. """
  967. :str: Identifies that this is a TensorFlow model.
  968. """
  969. return "tf"
  970. def build(self, input_shape=None):
  971. pass # This is just here to make sure we don't call the superclass build()
  972. def __init__(self, config, *inputs, **kwargs):
  973. super().__init__(*inputs, **kwargs)
  974. if not isinstance(config, PretrainedConfig):
  975. raise TypeError(
  976. f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
  977. "`PretrainedConfig`. To create a model from a pretrained model use "
  978. f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
  979. )
  980. # Save config and origin of the pretrained weights if given in model
  981. self.config = config
  982. self.name_or_path = config.name_or_path
  983. self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
  984. self._set_save_spec(self.input_signature)
  985. logger.warning_once(
  986. "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We "
  987. "recommend migrating to PyTorch classes or pinning your version of Transformers."
  988. )
  989. def get_config(self):
  990. return self.config.to_dict()
  991. @functools.wraps(keras.Model.fit)
  992. def fit(self, *args, **kwargs):
  993. args, kwargs = convert_batch_encoding(*args, **kwargs)
  994. return super().fit(*args, **kwargs)
  995. @functools.wraps(keras.Model.train_on_batch)
  996. def train_on_batch(self, *args, **kwargs):
  997. args, kwargs = convert_batch_encoding(*args, **kwargs)
  998. return super().train_on_batch(*args, **kwargs)
  999. @functools.wraps(keras.Model.test_on_batch)
  1000. def test_on_batch(self, *args, **kwargs):
  1001. args, kwargs = convert_batch_encoding(*args, **kwargs)
  1002. return super().test_on_batch(*args, **kwargs)
  1003. @functools.wraps(keras.Model.predict_on_batch)
  1004. def predict_on_batch(self, *args, **kwargs):
  1005. args, kwargs = convert_batch_encoding(*args, **kwargs)
  1006. return super().predict_on_batch(*args, **kwargs)
  1007. @functools.wraps(keras.Model.predict)
  1008. def predict(self, *args, **kwargs):
  1009. args, kwargs = convert_batch_encoding(*args, **kwargs)
  1010. return super().predict(*args, **kwargs)
  1011. @functools.wraps(keras.Model.evaluate)
  1012. def evaluate(self, *args, **kwargs):
  1013. args, kwargs = convert_batch_encoding(*args, **kwargs)
  1014. return super().evaluate(*args, **kwargs)
  1015. @classmethod
  1016. def from_config(cls, config, **kwargs):
  1017. if isinstance(config, PretrainedConfig):
  1018. return cls._from_config(config, **kwargs)
  1019. return cls._from_config(cls.config_class.from_dict(config, **kwargs))
  1020. @classmethod
  1021. def _from_config(cls, config, **kwargs):
  1022. """
  1023. All context managers that the model should be initialized under go here.
  1024. """
  1025. return cls(config, **kwargs)
  1026. def get_head_mask(self, head_mask: tf.Tensor | None, num_hidden_layers: int) -> tf.Tensor:
  1027. """
  1028. Prepare the head mask if needed.
  1029. Args:
  1030. head_mask (`tf.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
  1031. The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
  1032. num_hidden_layers (`int`):
  1033. The number of hidden layers in the model.
  1034. Returns:
  1035. `tf.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
  1036. `[None]` for each layer.
  1037. """
  1038. if head_mask is not None:
  1039. head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
  1040. else:
  1041. head_mask = [None] * num_hidden_layers
  1042. return head_mask
  1043. def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
  1044. """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
  1045. if head_mask.shape.rank == 1:
  1046. head_mask = head_mask[None, None, :, None, None]
  1047. head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0)
  1048. elif head_mask.shape.rank == 2:
  1049. head_mask = head_mask[:, None, :, None, None]
  1050. assert head_mask.shape.rank == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
  1051. head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility
  1052. return head_mask
  1053. @tf.function
  1054. def serving(self, inputs):
  1055. """
  1056. Args:
  1057. Method used for serving the model. Does not have a specific signature, but will be specialized as concrete
  1058. functions when saving with `save_pretrained`.
  1059. inputs (`dict[str, tf.Tensor]`):
  1060. The input of the saved model as a dictionary of tensors.
  1061. """
  1062. output = self.call(inputs)
  1063. return self.serving_output(output)
  1064. @property
  1065. def input_signature(self) -> dict[str, tf.TensorSpec]:
  1066. """
  1067. This property should return a dict mapping input names to tf.TensorSpec objects, representing the expected
  1068. shape and dtype for model inputs. It is used for both serving and for generating dummy inputs.
  1069. """
  1070. model_inputs = list(inspect.signature(self.call).parameters)
  1071. sig = {}
  1072. if "input_ids" in model_inputs:
  1073. if self.__class__.__name__.endswith("ForMultipleChoice"):
  1074. text_dims = 3
  1075. else:
  1076. text_dims = 2
  1077. for input_name in (
  1078. "input_ids",
  1079. "attention_mask",
  1080. "token_type_ids",
  1081. "decoder_input_ids",
  1082. "decoder_attention_mask",
  1083. ):
  1084. if input_name in model_inputs:
  1085. sig[input_name] = tf.TensorSpec([None] * text_dims, tf.int32, name=input_name)
  1086. if "pixel_values" in model_inputs:
  1087. pixel_values_shape = [None, None, None, None]
  1088. if hasattr(self.config, "vision_config"):
  1089. vision_config = self.config.vision_config
  1090. else:
  1091. vision_config = self.config
  1092. if hasattr(vision_config, "num_channels"):
  1093. pixel_values_shape[1] = vision_config.num_channels
  1094. else:
  1095. raise NotImplementedError(
  1096. "Could not infer number of channels from config, please override input_signature to specify input shapes."
  1097. )
  1098. if hasattr(vision_config, "image_size"):
  1099. pixel_values_shape[2] = pixel_values_shape[3] = vision_config.image_size
  1100. elif hasattr(vision_config, "input_size"):
  1101. pixel_values_shape[2] = pixel_values_shape[3] = vision_config.input_size
  1102. else:
  1103. raise NotImplementedError(
  1104. "Could not infer input image shape from config, please override input_signature to specify input shapes."
  1105. )
  1106. sig["pixel_values"] = tf.TensorSpec(pixel_values_shape, tf.float32, name="pixel_values")
  1107. if "input_features" in model_inputs:
  1108. raise NotImplementedError("Audio models need a manually defined input_signature")
  1109. return sig
  1110. def serving_output(self, output):
  1111. """
  1112. Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
  1113. """
  1114. if not isinstance(output, ModelOutput):
  1115. return output
  1116. for key in output:
  1117. if key.endswith("hidden_states") and not getattr(self.config, "output_hidden_states", False):
  1118. output[key] = None
  1119. elif key.endswith("attentions") and not getattr(self.config, "output_attentions", False):
  1120. output[key] = None
  1121. elif key == "past_key_values" and not getattr(self.config, "use_cache", False):
  1122. output[key] = None
  1123. elif key == "cross_attentions" and not (
  1124. getattr(self.config, "output_attentions", False) and getattr(self.config, "add_cross_attention", False)
  1125. ):
  1126. output[key] = None
  1127. if isinstance(output[key], (tuple, list)):
  1128. try:
  1129. output[key] = tf.convert_to_tensor(output[key])
  1130. except (ValueError, tf.errors.InvalidArgumentError):
  1131. pass # Layers may not have the same dimensions
  1132. return output
  1133. @classmethod
  1134. def can_generate(cls) -> bool:
  1135. """
  1136. Returns whether this model can generate sequences with `.generate()`.
  1137. Returns:
  1138. `bool`: Whether this model can generate sequences with `.generate()`.
  1139. """
  1140. # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
  1141. # Alternatively, the model can also have a custom `generate` function.
  1142. if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
  1143. return False
  1144. return True
  1145. def get_input_embeddings(self) -> keras.layers.Layer:
  1146. """
  1147. Returns the model's input embeddings layer.
  1148. Returns:
  1149. `tf.Variable`: The embeddings layer mapping vocabulary to hidden states.
  1150. """
  1151. main_layer = getattr(self, self.base_model_prefix, self)
  1152. if main_layer is not self:
  1153. return main_layer.get_input_embeddings()
  1154. else:
  1155. raise NotImplementedError
  1156. def _save_checkpoint(self, checkpoint_dir, epoch):
  1157. if not os.path.isdir(checkpoint_dir):
  1158. os.mkdir(checkpoint_dir)
  1159. # We avoid tf.train.checkpoint or saving weights in TF format, even though that includes optimizer
  1160. # state for us, because it requires special handling for objects like custom losses, which we use
  1161. # internally and which users are likely to use too
  1162. weights_path = os.path.join(checkpoint_dir, "weights.h5")
  1163. self.save_weights(weights_path)
  1164. extra_data = {"epoch": epoch, "optimizer_state": self.optimizer.get_weights()}
  1165. extra_data_path = os.path.join(checkpoint_dir, "extra_data.pickle")
  1166. with open(extra_data_path, "wb") as f:
  1167. pickle.dump(extra_data, f)
  1168. def prepare_tf_dataset(
  1169. self,
  1170. dataset: datasets.Dataset, # noqa:F821
  1171. batch_size: int = 8,
  1172. shuffle: bool = True,
  1173. tokenizer: PreTrainedTokenizerBase | None = None,
  1174. collate_fn: Callable | None = None,
  1175. collate_fn_args: dict[str, Any] | None = None,
  1176. drop_remainder: bool | None = None,
  1177. prefetch: bool = True,
  1178. ):
  1179. """
  1180. Wraps a HuggingFace [`~datasets.Dataset`] as a `tf.data.Dataset` with collation and batching. This method is
  1181. designed to create a "ready-to-use" dataset that can be passed directly to Keras methods like `fit()` without
  1182. further modification. The method will drop columns from the dataset if they don't match input names for the
  1183. model. If you want to specify the column names to return rather than using the names that match this model, we
  1184. recommend using `Dataset.to_tf_dataset()` instead.
  1185. Args:
  1186. dataset (`Any`):
  1187. A [~`datasets.Dataset`] to be wrapped as a `tf.data.Dataset`.
  1188. batch_size (`int`, *optional*, defaults to 8):
  1189. The size of batches to return.
  1190. shuffle (`bool`, defaults to `True`):
  1191. Whether to return samples from the dataset in random order. Usually `True` for training datasets and
  1192. `False` for validation/test datasets.
  1193. tokenizer ([`PreTrainedTokenizerBase`], *optional*):
  1194. A `PreTrainedTokenizer` that will be used to pad samples to create batches. Has no effect if a specific
  1195. `collate_fn` is passed instead.
  1196. collate_fn (`Callable`, *optional*):
  1197. A function that collates samples from the dataset into a single batch. Defaults to
  1198. `DefaultDataCollator` if no `tokenizer` is supplied or `DataCollatorWithPadding` if a `tokenizer` is
  1199. passed.
  1200. collate_fn_args (`dict[str, Any]`, *optional*):
  1201. A dict of arguments to pass to the `collate_fn` alongside the list of samples.
  1202. drop_remainder (`bool`, *optional*):
  1203. Whether to drop the final batch, if the batch_size does not evenly divide the dataset length. Defaults
  1204. to the same setting as `shuffle`.
  1205. prefetch (`bool`, defaults to `True`):
  1206. Whether to add prefetching to the end of the `tf.data` pipeline. This is almost always beneficial for
  1207. performance, but can be disabled in edge cases.
  1208. Returns:
  1209. `Dataset`: A `tf.data.Dataset` which is ready to pass to the Keras API.
  1210. """
  1211. requires_backends(self, ["datasets"])
  1212. import datasets
  1213. if collate_fn is None:
  1214. if tokenizer is None:
  1215. collate_fn = DefaultDataCollator(return_tensors="np")
  1216. else:
  1217. collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np")
  1218. if collate_fn_args is None:
  1219. collate_fn_args = {}
  1220. if not isinstance(dataset, datasets.Dataset):
  1221. raise TypeError("Dataset argument should be a datasets.Dataset!")
  1222. model_inputs = list(inspect.signature(self.call).parameters)
  1223. model_labels = find_labels(self.__class__)
  1224. if "cols_to_retain" in list(inspect.signature(dataset._get_output_signature).parameters.keys()):
  1225. output_signature, _ = dataset._get_output_signature(
  1226. dataset,
  1227. batch_size=None,
  1228. collate_fn=collate_fn,
  1229. collate_fn_args=collate_fn_args,
  1230. cols_to_retain=model_inputs,
  1231. )
  1232. else:
  1233. # TODO Matt: This is a workaround for older versions of datasets that are missing the `cols_to_retain`
  1234. # argument. We should remove this once the minimum supported version of datasets is > 2.3.2
  1235. unwanted_columns = [
  1236. feature
  1237. for feature in dataset.features
  1238. if feature not in model_inputs and feature not in ("label_ids", "label")
  1239. ]
  1240. dataset = dataset.remove_columns(unwanted_columns)
  1241. output_signature, _ = dataset._get_output_signature(
  1242. dataset, batch_size=None, collate_fn=collate_fn, collate_fn_args=collate_fn_args
  1243. )
  1244. output_columns = list(output_signature.keys())
  1245. feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels]
  1246. label_cols = [col for col in output_columns if col in model_labels]
  1247. # Backwards compatibility for older versions of datasets. Previously, if `columns` or `label_cols`
  1248. # were a single element list, the returned element spec would be a single element. Now, passing [feature]
  1249. # will return a dict structure {"feature": feature}, and passing a single string will return a single element.
  1250. feature_cols = feature_cols[0] if len(feature_cols) == 1 else feature_cols
  1251. label_cols = label_cols[0] if len(label_cols) == 1 else label_cols
  1252. if drop_remainder is None:
  1253. drop_remainder = shuffle
  1254. tf_dataset = dataset.to_tf_dataset(
  1255. columns=feature_cols,
  1256. label_cols=label_cols,
  1257. batch_size=batch_size,
  1258. shuffle=shuffle,
  1259. drop_remainder=drop_remainder,
  1260. collate_fn=collate_fn,
  1261. collate_fn_args=collate_fn_args,
  1262. prefetch=prefetch,
  1263. )
  1264. return tf_dataset
  1265. def compile(
  1266. self,
  1267. optimizer="rmsprop",
  1268. loss="auto_with_warning",
  1269. metrics=None,
  1270. loss_weights=None,
  1271. weighted_metrics=None,
  1272. run_eagerly=None,
  1273. steps_per_execution=None,
  1274. **kwargs,
  1275. ):
  1276. """
  1277. This is a thin wrapper that sets the model's loss output head as the loss if the user does not specify a loss
  1278. function themselves.
  1279. """
  1280. if loss in ("auto_with_warning", "passthrough"): # "passthrough" for workflow backward compatibility
  1281. logger.info(
  1282. "No loss specified in compile() - the model's internal loss computation will be used as the "
  1283. "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
  1284. "To disable this behaviour please pass a loss argument, or explicitly pass "
  1285. "`loss=None` if you do not want your model to compute a loss. You can also specify `loss='auto'` to "
  1286. "get the internal loss without printing this info string."
  1287. )
  1288. loss = "auto"
  1289. if loss == "auto":
  1290. loss = dummy_loss
  1291. self._using_dummy_loss = True
  1292. else:
  1293. self._using_dummy_loss = False
  1294. parent_args = list(inspect.signature(keras.Model.compile).parameters.keys())
  1295. # This argument got renamed, we need to support both versions
  1296. if "steps_per_execution" in parent_args:
  1297. super().compile(
  1298. optimizer=optimizer,
  1299. loss=loss,
  1300. metrics=metrics,
  1301. loss_weights=loss_weights,
  1302. weighted_metrics=weighted_metrics,
  1303. run_eagerly=run_eagerly,
  1304. steps_per_execution=steps_per_execution,
  1305. **kwargs,
  1306. )
  1307. else:
  1308. super().compile(
  1309. optimizer=optimizer,
  1310. loss=loss,
  1311. metrics=metrics,
  1312. loss_weights=loss_weights,
  1313. weighted_metrics=weighted_metrics,
  1314. run_eagerly=run_eagerly,
  1315. experimental_steps_per_execution=steps_per_execution,
  1316. **kwargs,
  1317. )
  1318. def compute_loss(self, *args, **kwargs):
  1319. if hasattr(keras.Model, "compute_loss"):
  1320. # This will be true in TF 2.8 or greater
  1321. return super().compute_loss(*args, **kwargs)
  1322. else:
  1323. warnings.warn(
  1324. "The old compute_loss method is deprecated as it conflicts with the Keras compute_loss "
  1325. "method added in TF 2.8. If you want the original HF compute_loss, please call "
  1326. "hf_compute_loss() instead. From TF versions >= 2.8, or Transformers versions >= 5, "
  1327. "calling compute_loss() will get the Keras method instead.",
  1328. FutureWarning,
  1329. )
  1330. return self.hf_compute_loss(*args, **kwargs)
  1331. def get_label_to_output_name_mapping(self):
  1332. arg_names = list(inspect.signature(self.call).parameters)
  1333. if self._label_to_output_map is not None:
  1334. return self._label_to_output_map
  1335. elif "start_positions" in arg_names:
  1336. return {"start_positions": "start_logits", "end_positions": "end_logits"}
  1337. elif "sentence_order_label" in arg_names:
  1338. return {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
  1339. elif "next_sentence_label" in arg_names:
  1340. return {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
  1341. elif "mc_labels" in arg_names:
  1342. return {"labels": "logits", "mc_labels": "mc_logits"}
  1343. else:
  1344. return {}
  1345. def train_step(self, data):
  1346. """
  1347. A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
  1348. and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
  1349. labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
  1350. that they are available to the model during the forward pass.
  1351. """
  1352. # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
  1353. arg_names = list(inspect.signature(self.call).parameters)
  1354. label_kwargs = find_labels(self.__class__)
  1355. label_to_output = self.get_label_to_output_name_mapping()
  1356. output_to_label = {val: key for key, val in label_to_output.items()}
  1357. if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
  1358. # Newer TF train steps leave this out
  1359. data = expand_1d(data)
  1360. x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
  1361. # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
  1362. # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
  1363. # In addition, modifying mutable Python inputs makes XLA compilation impossible.
  1364. if isinstance(x, dict):
  1365. x = x.copy()
  1366. if isinstance(y, dict):
  1367. y = y.copy()
  1368. # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
  1369. # if those keys are not already present in the input dict
  1370. if self._using_dummy_loss and y is not None:
  1371. # If y is a tensor and the model only has one label-like input, map y to that input
  1372. if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
  1373. if isinstance(x, tf.Tensor):
  1374. x = {arg_names[0]: x}
  1375. label_kwarg = next(iter(label_kwargs))
  1376. if label_kwarg not in x:
  1377. x[label_kwarg] = y
  1378. # Otherwise, copy keys from y to x as long as they weren't already present in x
  1379. elif isinstance(y, dict):
  1380. if isinstance(x, tf.Tensor):
  1381. x = {arg_names[0]: x}
  1382. for key, val in y.items():
  1383. if key in arg_names and key not in x:
  1384. x[key] = val
  1385. elif output_to_label.get(key) in arg_names and key not in x:
  1386. x[output_to_label[key]] = val
  1387. if y is None:
  1388. y = {key: val for key, val in x.items() if key in label_kwargs}
  1389. if not y and not self._using_dummy_loss:
  1390. raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
  1391. if isinstance(y, dict):
  1392. # Rename labels at this point to match output heads
  1393. y = {label_to_output.get(key, key): val for key, val in y.items()}
  1394. # Run forward pass.
  1395. with tf.GradientTape() as tape:
  1396. if self._using_dummy_loss and "return_loss" in arg_names:
  1397. y_pred = self(x, training=True, return_loss=True)
  1398. else:
  1399. y_pred = self(x, training=True)
  1400. if self._using_dummy_loss:
  1401. loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
  1402. else:
  1403. loss = None
  1404. # This next block matches outputs to label keys. Tensorflow's standard method for doing this
  1405. # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
  1406. if isinstance(y, dict) and len(y) == 1:
  1407. if list(y.keys())[0] in y_pred:
  1408. y_pred = y_pred[list(y.keys())[0]]
  1409. elif list(y_pred.keys())[0] == "loss":
  1410. y_pred = y_pred[1]
  1411. else:
  1412. y_pred = y_pred[0]
  1413. _, y = y.popitem()
  1414. elif isinstance(y, dict):
  1415. # If the labels are a dict, match keys from the output by name
  1416. y_pred = {key: val for key, val in y_pred.items() if key in y}
  1417. elif isinstance(y, (tuple, list)):
  1418. # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
  1419. if list(y_pred.keys())[0] == "loss":
  1420. y_pred = y_pred.to_tuple()[1:]
  1421. else:
  1422. y_pred = y_pred.to_tuple()
  1423. y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems
  1424. else:
  1425. # If the labels are a single tensor, match them to the first non-loss tensor in the output
  1426. if list(y_pred.keys())[0] == "loss":
  1427. y_pred = y_pred[1]
  1428. else:
  1429. y_pred = y_pred[0]
  1430. if loss is None:
  1431. loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
  1432. # Run backwards pass.
  1433. self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
  1434. self.compiled_metrics.update_state(y, y_pred, sample_weight)
  1435. # Collect metrics to return
  1436. return_metrics = {}
  1437. for metric in self.metrics:
  1438. result = metric.result()
  1439. if isinstance(result, dict):
  1440. return_metrics.update(result)
  1441. else:
  1442. return_metrics[metric.name] = result
  1443. return return_metrics
  1444. def test_step(self, data):
  1445. """
  1446. A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
  1447. and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
  1448. labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
  1449. that they are available to the model during the forward pass.
  1450. """
  1451. # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
  1452. arg_names = list(inspect.signature(self.call).parameters)
  1453. label_kwargs = find_labels(self.__class__)
  1454. label_to_output = self.get_label_to_output_name_mapping()
  1455. output_to_label = {val: key for key, val in label_to_output.items()}
  1456. if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
  1457. # Newer versions leave this out
  1458. data = expand_1d(data)
  1459. x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
  1460. # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
  1461. # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
  1462. # In addition, modifying mutable Python inputs makes XLA compilation impossible.
  1463. if isinstance(x, dict):
  1464. x = x.copy()
  1465. if isinstance(y, dict):
  1466. y = y.copy()
  1467. # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
  1468. # if those keys are not already present in the input dict
  1469. if self._using_dummy_loss and y is not None:
  1470. arg_names = list(inspect.signature(self.call).parameters)
  1471. # If y is a tensor and the model only has one label-like input, map y to that input
  1472. if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
  1473. if isinstance(x, tf.Tensor):
  1474. x = {arg_names[0]: x}
  1475. label_kwarg = next(iter(label_kwargs))
  1476. if label_kwarg not in x:
  1477. x[label_kwarg] = y
  1478. # Otherwise, copy keys from y to x as long as they weren't already present in x
  1479. elif isinstance(y, dict):
  1480. if isinstance(x, tf.Tensor):
  1481. x = {arg_names[0]: x}
  1482. for key, val in y.items():
  1483. if key in arg_names and key not in x:
  1484. x[key] = val
  1485. elif output_to_label.get(key) in arg_names and key not in x:
  1486. x[output_to_label[key]] = val
  1487. if y is None:
  1488. y = {key: val for key, val in x.items() if key in label_kwargs}
  1489. if not y and not self._using_dummy_loss:
  1490. raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
  1491. if isinstance(y, dict):
  1492. # Rename labels at this point to match output heads
  1493. y = {label_to_output.get(key, key): val for key, val in y.items()}
  1494. # Run forward pass.
  1495. if self._using_dummy_loss and "return_loss" in arg_names:
  1496. y_pred = self(x, return_loss=True, training=False)
  1497. else:
  1498. y_pred = self(x, training=False)
  1499. if self._using_dummy_loss:
  1500. loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
  1501. else:
  1502. loss = None
  1503. # This next block matches outputs to label keys. Tensorflow's standard method for doing this
  1504. # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
  1505. if isinstance(y, dict) and len(y) == 1:
  1506. if list(y.keys())[0] in y_pred:
  1507. y_pred = y_pred[list(y.keys())[0]]
  1508. elif list(y_pred.keys())[0] == "loss":
  1509. y_pred = y_pred[1]
  1510. else:
  1511. y_pred = y_pred[0]
  1512. _, y = y.popitem()
  1513. elif isinstance(y, dict):
  1514. # If the labels are a dict, match keys from the output by name
  1515. y_pred = {key: val for key, val in y_pred.items() if key in y}
  1516. elif isinstance(y, (tuple, list)):
  1517. # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
  1518. if list(y_pred.keys())[0] == "loss":
  1519. y_pred = y_pred.to_tuple()[1:]
  1520. else:
  1521. y_pred = y_pred.to_tuple()
  1522. y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems
  1523. else:
  1524. # If the labels are a single tensor, match them to the first non-loss tensor in the output
  1525. if list(y_pred.keys())[0] == "loss":
  1526. y_pred = y_pred[1]
  1527. else:
  1528. y_pred = y_pred[0]
  1529. if loss is None:
  1530. loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
  1531. self.compiled_metrics.update_state(y, y_pred, sample_weight)
  1532. # Collect metrics to return
  1533. return_metrics = {}
  1534. for metric in self.metrics:
  1535. result = metric.result()
  1536. if isinstance(result, dict):
  1537. return_metrics.update(result)
  1538. else:
  1539. return_metrics[metric.name] = result
  1540. return return_metrics
  1541. def create_model_card(
  1542. self,
  1543. output_dir,
  1544. model_name: str,
  1545. language: str | None = None,
  1546. license: str | None = None,
  1547. tags: str | None = None,
  1548. finetuned_from: str | None = None,
  1549. tasks: str | None = None,
  1550. dataset_tags: str | list[str] | None = None,
  1551. dataset: str | list[str] | None = None,
  1552. dataset_args: str | list[str] | None = None,
  1553. ):
  1554. """
  1555. Creates a draft of a model card using the information available to the `Trainer`.
  1556. Args:
  1557. output_dir (`str` or `os.PathLike`):
  1558. The folder in which to create the model card.
  1559. model_name (`str`, *optional*):
  1560. The name of the model.
  1561. language (`str`, *optional*):
  1562. The language of the model (if applicable)
  1563. license (`str`, *optional*):
  1564. The license of the model. Will default to the license of the pretrained model used, if the original
  1565. model given to the `Trainer` comes from a repo on the Hub.
  1566. tags (`str` or `list[str]`, *optional*):
  1567. Some tags to be included in the metadata of the model card.
  1568. finetuned_from (`str`, *optional*):
  1569. The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
  1570. of the original model given to the `Trainer` (if it comes from the Hub).
  1571. tasks (`str` or `list[str]`, *optional*):
  1572. One or several task identifiers, to be included in the metadata of the model card.
  1573. dataset_tags (`str` or `list[str]`, *optional*):
  1574. One or several dataset tags, to be included in the metadata of the model card.
  1575. dataset (`str` or `list[str]`, *optional*):
  1576. One or several dataset identifiers, to be included in the metadata of the model card.
  1577. dataset_args (`str` or `list[str]`, *optional*):
  1578. One or several dataset arguments, to be included in the metadata of the model card.
  1579. """
  1580. # Avoids a circular import by doing this when necessary.
  1581. from .modelcard import TrainingSummary # tests_ignore
  1582. training_summary = TrainingSummary.from_keras(
  1583. self,
  1584. keras_history=self.history,
  1585. language=language,
  1586. license=license,
  1587. tags=tags,
  1588. model_name=model_name,
  1589. finetuned_from=finetuned_from,
  1590. tasks=tasks,
  1591. dataset_tags=dataset_tags,
  1592. dataset=dataset,
  1593. dataset_args=dataset_args,
  1594. )
  1595. model_card = training_summary.to_model_card()
  1596. with open(os.path.join(output_dir, "README.md"), "w") as f:
  1597. f.write(model_card)
  1598. def set_input_embeddings(self, value):
  1599. """
  1600. Set model's input embeddings
  1601. Args:
  1602. value (`tf.Variable`):
  1603. The new weights mapping hidden states to vocabulary.
  1604. """
  1605. main_layer = getattr(self, self.base_model_prefix)
  1606. if main_layer is None:
  1607. raise NotImplementedError("The model does not implements the base_model_prefix attribute.")
  1608. try:
  1609. main_layer.set_input_embeddings(value)
  1610. except AttributeError:
  1611. logger.info("Building the model")
  1612. self.build_in_name_scope()
  1613. main_layer.set_input_embeddings(value)
  1614. def get_output_embeddings(self) -> None | keras.layers.Layer:
  1615. """
  1616. Returns the model's output embeddings
  1617. Returns:
  1618. `tf.Variable`: The new weights mapping vocabulary to hidden states.
  1619. """
  1620. if self.get_lm_head() is not None:
  1621. lm_head = self.get_lm_head()
  1622. try:
  1623. return lm_head.get_output_embeddings()
  1624. except AttributeError:
  1625. logger.info("Building the model")
  1626. self.build_in_name_scope()
  1627. return lm_head().get_output_embeddings()
  1628. return None # Overwrite for models with output embeddings
  1629. def set_output_embeddings(self, value):
  1630. """
  1631. Set model's output embeddings
  1632. Args:
  1633. value (`tf.Variable`):
  1634. The new weights mapping hidden states to vocabulary.
  1635. """
  1636. if self.get_lm_head() is not None:
  1637. lm_head = self.get_lm_head()
  1638. try:
  1639. lm_head.set_output_embeddings(value)
  1640. except AttributeError:
  1641. logger.info("Building the model")
  1642. self.build_in_name_scope()
  1643. lm_head.set_output_embeddings(value)
  1644. def get_output_layer_with_bias(self) -> None | keras.layers.Layer:
  1645. """
  1646. Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the
  1647. embeddings
  1648. Return:
  1649. `keras.layers.Layer`: The layer that handles the bias, None if not an LM model.
  1650. """
  1651. warnings.warn(
  1652. "The method get_output_layer_with_bias is deprecated. Please use `get_lm_head` instead.", FutureWarning
  1653. )
  1654. return self.get_lm_head()
  1655. def get_prefix_bias_name(self) -> None | str:
  1656. """
  1657. Get the concatenated _prefix name of the bias from the model name to the parent layer
  1658. Return:
  1659. `str`: The _prefix name of the bias.
  1660. """
  1661. warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
  1662. return None
  1663. def get_bias(self) -> None | dict[str, tf.Variable]:
  1664. """
  1665. Dict of bias attached to an LM head. The key represents the name of the bias attribute.
  1666. Return:
  1667. `tf.Variable`: The weights representing the bias, None if not an LM model.
  1668. """
  1669. if self.get_lm_head() is not None:
  1670. lm_head = self.get_lm_head()
  1671. try:
  1672. return lm_head.get_bias()
  1673. except AttributeError:
  1674. self.build_in_name_scope()
  1675. return lm_head.get_bias()
  1676. return None
  1677. def set_bias(self, value):
  1678. """
  1679. Set all the bias in the LM head.
  1680. Args:
  1681. value (`dict[tf.Variable]`):
  1682. All the new bias attached to an LM head.
  1683. """
  1684. if self.get_lm_head() is not None:
  1685. lm_head = self.get_lm_head()
  1686. try:
  1687. lm_head.set_bias(value)
  1688. except AttributeError:
  1689. self.build_in_name_scope()
  1690. lm_head.set_bias(value)
  1691. def get_lm_head(self) -> keras.layers.Layer:
  1692. """
  1693. The LM Head layer. This method must be overwritten by all the models that have a lm head.
  1694. Return:
  1695. `keras.layers.Layer`: The LM head layer if the model has one, None if not.
  1696. """
  1697. return None
  1698. def resize_token_embeddings(self, new_num_tokens: int | None = None) -> keras.layers.Embedding | tf.Variable:
  1699. """
  1700. Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
  1701. Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
  1702. Arguments:
  1703. new_num_tokens (`int`, *optional*):
  1704. The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
  1705. vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
  1706. returns a pointer to the input tokens without doing anything.
  1707. Return:
  1708. `tf.Variable` or `keras.layers.Embedding`: Pointer to the input tokens of the model.
  1709. """
  1710. # TODO (joao): flagged for replacement (by `_v2_resized_token_embeddings`) due to embeddings refactor
  1711. # Run the new code path if the model has a keras embeddings layer
  1712. if isinstance(self.get_input_embeddings(), keras.layers.Embedding):
  1713. return self._v2_resized_token_embeddings(new_num_tokens)
  1714. if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
  1715. return self._get_word_embedding_weight(self.get_input_embeddings())
  1716. model_embeds = self._resize_token_embeddings(new_num_tokens)
  1717. # Update base model and current model config
  1718. self.config.vocab_size = new_num_tokens
  1719. return model_embeds
  1720. def _v2_resized_token_embeddings(self, new_num_tokens: int | None = None) -> keras.layers.Embedding:
  1721. """
  1722. Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
  1723. Arguments:
  1724. new_num_tokens (`int`, *optional*):
  1725. The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
  1726. vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
  1727. returns a pointer to the input tokens without doing anything.
  1728. Return:
  1729. `keras.layers.Embedding`: Pointer to the input tokens of the model.
  1730. """
  1731. if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
  1732. return self.get_input_embeddings()
  1733. model_embeds = self._v2_resize_token_embeddings(new_num_tokens)
  1734. # Update base model and current model config
  1735. self.config.vocab_size = new_num_tokens
  1736. return model_embeds
  1737. def _get_word_embedding_weight(model, embedding_layer):
  1738. # TODO (joao): flagged for detection due to embeddings refactor
  1739. # If the variable holds the weights themselves, return them
  1740. if isinstance(embedding_layer, tf.Tensor):
  1741. return embedding_layer
  1742. # Otherwise, try to get them from the layer's attributes
  1743. embeds = getattr(embedding_layer, "weight", None)
  1744. if embeds is not None:
  1745. return embeds
  1746. embeds = getattr(embedding_layer, "decoder", None)
  1747. if embeds is not None:
  1748. return embeds
  1749. # The reason why the attributes don't exist might be
  1750. # because the model is not built, so retry getting
  1751. # the argument after building the model
  1752. model.build_in_name_scope()
  1753. embeds = getattr(embedding_layer, "weight", None)
  1754. if embeds is not None:
  1755. return embeds
  1756. embeds = getattr(embedding_layer, "decoder", None)
  1757. if embeds is not None:
  1758. return embeds
  1759. return None
  1760. def _resize_token_embeddings(self, new_num_tokens):
  1761. # TODO (joao): flagged for replacement (by `_v2_resize_token_embeddings`) due to embeddings refactor
  1762. old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings())
  1763. new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
  1764. # if word embeddings are not tied, make sure that lm head bias is resized as well
  1765. if self.get_bias() is not None:
  1766. old_lm_head_bias = self.get_bias()
  1767. new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
  1768. self.set_bias(new_lm_head_bias)
  1769. # if word embeddings are not tied, make sure that lm head decoder is resized as well
  1770. if self.get_output_embeddings() is not None:
  1771. old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
  1772. new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
  1773. self.set_output_embeddings(new_lm_head_decoder)
  1774. self.set_input_embeddings(new_embeddings)
  1775. return self.get_input_embeddings()
  1776. def _v2_resize_token_embeddings(self, new_num_tokens):
  1777. old_embeddings = self.get_input_embeddings()
  1778. new_embeddings = self._v2_get_resized_embeddings(old_embeddings, new_num_tokens)
  1779. self.set_input_embeddings(new_embeddings)
  1780. # If word embeddings are not tied, make sure that lm head bias is resized as well
  1781. if self.get_bias() is not None:
  1782. old_lm_head_bias = self.get_bias()
  1783. new_lm_head_bias = self._v2_get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
  1784. self.set_bias(new_lm_head_bias)
  1785. # If word embeddings are not tied, make sure that lm head decoder is resized as well.
  1786. tied_weights = self.get_input_embeddings() == self.get_output_embeddings()
  1787. if self.get_output_embeddings() is not None and not tied_weights:
  1788. old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
  1789. # TODO (joao): this one probably needs a v2 version with other models
  1790. new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
  1791. self.set_output_embeddings(new_lm_head_decoder)
  1792. return self.get_input_embeddings()
def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens):
    """
    Build resized LM-head bias variables from the old ones. Increasing the size adds zero-initialized entries at
    the end; reducing the size removes entries from the end.

    Args:
        old_lm_head_bias (`dict[str, tf.Variable]`):
            Old lm head biases to be resized, keyed by bias attribute name. Each bias is either rank 1
            (`[num_tokens]`) or rank 2 (`[first_dim, num_tokens]`); the token axis is the last one.
        new_num_tokens (`int`):
            New number of tokens in the linear matrix. It is used directly in arithmetic, so `None` is not
            supported here.

    Return:
        `dict[str, tf.Variable]`: One new variable per bias attribute, created with `self.add_weight` and filled
        with the preserved old values (new positions stay zero).
    """
    # TODO (joao): flagged for replacement (by `_v2_get_resized_lm_head_bias`) due to embeddings refactor
    new_lm_head_bias = {}
    for attr, weight in old_lm_head_bias.items():
        # Rank-1 bias: only a token axis. Otherwise the shape is unpacked into exactly two values, so any
        # non-rank-1 bias must be rank 2.
        first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight)
        size_diff = new_num_tokens - old_num_tokens
        final_shape = [new_num_tokens] if first_dim is None else [first_dim, new_num_tokens]
        # initialize new bias
        if tf.math.greater(size_diff, 0):
            # Growing: pad the old values with a -1 placeholder. `bias_mask` is True only on the copied old
            # positions, so the placeholders are replaced by the zero-initialized new weight in `tf.where` below.
            padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]]
            current_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape), constant_values=-1)
            num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
            # For rank-2 biases the mask has a leading dim of 1 and relies on `tf.where` broadcasting it over
            # `first_dim`.
            mask_shape = [num_tokens_to_copy] if first_dim is None else [1, num_tokens_to_copy]
            bias_mask = tf.fill(tf.convert_to_tensor(mask_shape), True)
            bias_mask = tf.pad(bias_mask, tf.convert_to_tensor(padding_shape), constant_values=False)
        else:
            # Shrinking (or same size): keep the first `new_num_tokens` entries; every position comes from the
            # old bias, so the mask is all True.
            slice_from = [0] if first_dim is None else [0, 0]
            current_bias = tf.slice(
                weight.value(), tf.convert_to_tensor(slice_from), tf.convert_to_tensor(final_shape)
            )
            bias_mask = tf.fill(tf.convert_to_tensor(final_shape), True)
        # Reuse the old variable name without its ":<index>" suffix (presumably so the resized weight keeps
        # the same name for checkpoint matching — TODO confirm).
        new_bias = self.add_weight(
            shape=final_shape,
            initializer="zeros",
            trainable=True,
            name=weight.name.split(":")[0],
        )
        init_bias = tf.where(bias_mask, current_bias, new_bias.value())
        new_bias.assign(init_bias)
        new_lm_head_bias[attr] = new_bias
    return new_lm_head_bias
  1837. def _v2_get_resized_lm_head_bias(
  1838. self, old_lm_head_bias: dict[str, tf.Variable], new_num_tokens: int
  1839. ) -> dict[str, tf.Tensor]:
  1840. """
  1841. Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
  1842. Reducing the size will remove vectors from the end
  1843. Args:
  1844. old_lm_head_bias (`dict[str, tf.Variable]`):
  1845. Old lm head bias to be resized.
  1846. new_num_tokens (`int`):
  1847. New number of tokens in the linear matrix. Increasing the size will add newly initialized vectors at
  1848. the end. Reducing the size will remove vectors from the end.
  1849. Return:
  1850. `tf.Tensor`: Values for the resized bias.
  1851. """
  1852. new_lm_head_bias = {}
  1853. for attr, weight in old_lm_head_bias.items():
  1854. # Determine the size difference (depending on the shape)
  1855. first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight)
  1856. size_diff = new_num_tokens - old_num_tokens
  1857. # Copy the old bias values to the new bias
  1858. if old_num_tokens > new_num_tokens:
  1859. new_bias = weight.value()[..., :new_num_tokens]
  1860. else:
  1861. padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]]
  1862. new_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape))
  1863. new_lm_head_bias[attr] = new_bias
  1864. return new_lm_head_bias
  1865. def _get_resized_lm_head_decoder(self, old_lm_head_decoder, new_num_tokens):
  1866. """
  1867. Build a resized decoder from the old ones. Increasing the size will add newly initialized vectors at the end.
  1868. Reducing the size will remove vectors from the end
  1869. Args:
  1870. old_lm_head_decoder (`tf.Variable`):
  1871. Old lm head decoder to be resized.
  1872. new_num_tokens (`int`, *optional*):
  1873. New number of tokens in the linear matrix.
  1874. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
  1875. vectors from the end. If not provided or `None`, just returns None
  1876. Return:
  1877. `tf.Variable`: Pointer to the resized decoder or None if the output embeddings are different from the input
  1878. ones.
  1879. """
  1880. new_lm_head_decoder = old_lm_head_decoder
  1881. is_input_output_equals = tf.reduce_any(
  1882. self._get_word_embedding_weight(self.get_input_embeddings()) == old_lm_head_decoder
  1883. )
  1884. if old_lm_head_decoder is not None and not is_input_output_equals:
  1885. old_embedding_dim = shape_list(old_lm_head_decoder)[1]
  1886. decoder_mask, current_decoder = init_copy_embeddings(old_lm_head_decoder, new_num_tokens)
  1887. new_lm_head_decoder = self.add_weight(
  1888. shape=(new_num_tokens, old_embedding_dim),
  1889. initializer="zeros",
  1890. trainable=True,
  1891. name=old_lm_head_decoder.name.split(":")[0],
  1892. )
  1893. init_decoder = tf.where(decoder_mask, current_decoder, new_lm_head_decoder.value())
  1894. new_lm_head_decoder.assign(init_decoder)
  1895. return new_lm_head_decoder
  1896. def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
  1897. """
  1898. Build a resized Embedding weights from a provided token Embedding weights. Increasing the size will add newly
  1899. initialized vectors at the end. Reducing the size will remove vectors from the end
  1900. Args:
  1901. old_embeddings (`tf.Variable`):
  1902. Old embeddings to be resized.
  1903. new_num_tokens (`int`, *optional*):
  1904. New number of tokens in the embedding matrix.
  1905. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
  1906. vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
  1907. `tf.Variable` module of the model without doing anything.
  1908. Return:
  1909. `tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is
  1910. `None`
  1911. """
  1912. # TODO (joao): flagged for replacement (by `_v2_get_resized_embeddings`) due to embeddings refactor
  1913. old_embedding_dim = shape_list(old_embeddings)[1]
  1914. init_range = getattr(self.config, "initializer_range", 0.02)
  1915. embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens)
  1916. new_embeddings = self.add_weight(
  1917. name=old_embeddings.name.split(":")[0],
  1918. shape=[new_num_tokens, old_embedding_dim],
  1919. initializer=get_initializer(init_range),
  1920. dtype=tf.float32,
  1921. )
  1922. init_embeddings = tf.where(embeddings_mask, current_embeddings, new_embeddings.value())
  1923. new_embeddings.assign(init_embeddings)
  1924. return new_embeddings
def _v2_get_resized_embeddings(
    self, old_embeddings: keras.layers.Embedding, new_num_tokens: int
) -> keras.layers.Embedding:
    """
    Build a resized Embedding layer from a provided Embedding layer. Increasing the size keeps the old vectors
    and appends newly initialized ones at the end; reducing the size removes vectors from the end.

    Args:
        old_embeddings (`keras.layers.Embedding`):
            Old embeddings to be resized.
        new_num_tokens (`int`):
            New number of tokens in the embedding matrix.

    Return:
        `keras.layers.Embedding`: A new, already-built Embedding layer holding the resized weights.
    """
    # Get the initialization range for the embeddings
    init_range = 0.02 # default value
    potential_initialization_variable_names = [
        "initializer_range", # most common
        "initializer_factor", # e.g. T5
        "init_std", # e.g BART
    ]
    # NOTE(review): there is no `break`, so if a config defines several of these names, the LAST one in the
    # list wins. `initializer_factor` is used directly as a stddev here — confirm that is intended.
    for var_name in potential_initialization_variable_names:
        if hasattr(self.config, var_name):
            init_range = getattr(self.config, var_name)
    # Get a new (initialized) embeddings layer
    new_embeddings = keras.layers.Embedding(
        input_dim=new_num_tokens,
        output_dim=old_embeddings.output_dim,
        embeddings_initializer=keras.initializers.TruncatedNormal(stddev=init_range),
        name=old_embeddings.embeddings.name[:-13], # exact same scoped name except "/embeddings:0"
    )
    # Call the layer once on a dummy input so that `new_embeddings.embeddings` exists; it is read below.
    new_embeddings(tf.constant([[0]]))
    # Copy the old embeddings to the new embeddings
    if old_embeddings.input_dim >= new_num_tokens:
        # Shrink (or same size): keep the first `new_num_tokens` rows of the old matrix.
        init_embeddings = old_embeddings.embeddings[:new_num_tokens]
    else:
        # Grow: all old rows, followed by the freshly initialized tail rows of the new layer.
        init_embeddings = tf.concat(
            [old_embeddings.embeddings, new_embeddings.embeddings[old_embeddings.input_dim :]], axis=0
        )
    new_embeddings.embeddings.assign(init_embeddings)
    return new_embeddings
  1966. def prune_heads(self, heads_to_prune):
  1967. """
  1968. Prunes heads of the base model.
  1969. Arguments:
  1970. heads_to_prune (`dict[int, list[int]]`):
  1971. Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads
  1972. to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on
  1973. layer 1 and heads 2 and 3 on layer 2.
  1974. """
  1975. raise NotImplementedError
  1976. def save_pretrained(
  1977. self,
  1978. save_directory,
  1979. saved_model=False,
  1980. version=1,
  1981. push_to_hub=False,
  1982. signatures=None,
  1983. max_shard_size: int | str = "5GB",
  1984. create_pr: bool = False,
  1985. safe_serialization: bool = False,
  1986. token: str | bool | None = None,
  1987. **kwargs,
  1988. ):
  1989. """
  1990. Save a model and its configuration file to a directory, so that it can be re-loaded using the
  1991. [`~TFPreTrainedModel.from_pretrained`] class method.
  1992. Arguments:
  1993. save_directory (`str`):
  1994. Directory to which to save. Will be created if it doesn't exist.
  1995. saved_model (`bool`, *optional*, defaults to `False`):
  1996. If the model has to be saved in saved model format as well or not.
  1997. version (`int`, *optional*, defaults to 1):
  1998. The version of the saved model. A saved model needs to be versioned in order to be properly loaded by
  1999. TensorFlow Serving as detailed in the official documentation
  2000. https://www.tensorflow.org/tfx/serving/serving_basic
  2001. push_to_hub (`bool`, *optional*, defaults to `False`):
  2002. Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
  2003. repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
  2004. namespace).
  2005. signatures (`dict` or `tf.function`, *optional*):
  2006. Model's signature used for serving. This will be passed to the `signatures` argument of model.save().
  2007. max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
  2008. The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
  2009. lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
  2010. <Tip warning={true}>
  2011. If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
  2012. which will be bigger than `max_shard_size`.
  2013. </Tip>
  2014. create_pr (`bool`, *optional*, defaults to `False`):
  2015. Whether or not to create a PR with the uploaded files or directly commit.
  2016. safe_serialization (`bool`, *optional*, defaults to `False`):
  2017. Whether to save the model using `safetensors` or the traditional TensorFlow way (that uses `h5`).
  2018. token (`str` or `bool`, *optional*):
  2019. The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
  2020. the token generated when running `hf auth login` (stored in `~/.huggingface`).
  2021. kwargs (`dict[str, Any]`, *optional*):
  2022. Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
  2023. """
  2024. use_auth_token = kwargs.pop("use_auth_token", None)
  2025. if use_auth_token is not None:
  2026. warnings.warn(
  2027. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  2028. FutureWarning,
  2029. )
  2030. if token is not None:
  2031. raise ValueError(
  2032. "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
  2033. )
  2034. token = use_auth_token
  2035. if token is not None:
  2036. kwargs["token"] = token
  2037. if os.path.isfile(save_directory):
  2038. logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
  2039. return
  2040. os.makedirs(save_directory, exist_ok=True)
  2041. if push_to_hub:
  2042. commit_message = kwargs.pop("commit_message", None)
  2043. repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
  2044. repo_id = self._create_repo(repo_id, **kwargs)
  2045. files_timestamps = self._get_files_timestamps(save_directory)
  2046. if saved_model:
  2047. # If `torch_dtype` is in the config with a torch dtype class as the value, we need to change it to string.
  2048. # (Although TF doesn't care about this attribute, we can't just remove it or set it to `None`.)
  2049. if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str):
  2050. self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1]
  2051. if signatures is None:
  2052. serving_default = self.serving.get_concrete_function(self.input_signature)
  2053. if any(spec.dtype == tf.int32 for spec in self.input_signature.values()):
  2054. int64_spec = {
  2055. key: tf.TensorSpec(
  2056. shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name
  2057. )
  2058. for key, spec in self.input_signature.items()
  2059. }
  2060. int64_serving = self.serving.get_concrete_function(int64_spec)
  2061. signatures = {"serving_default": serving_default, "int64_serving": int64_serving}
  2062. else:
  2063. signatures = serving_default
  2064. saved_model_dir = os.path.join(save_directory, "saved_model", str(version))
  2065. self.save(saved_model_dir, include_optimizer=False, signatures=signatures)
  2066. logger.info(f"Saved model created in {saved_model_dir}")
  2067. # Save configuration file
  2068. self.config.architectures = [self.__class__.__name__[2:]]
  2069. # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
  2070. # loaded from the Hub.
  2071. if self._auto_class is not None:
  2072. custom_object_save(self, save_directory, config=self.config)
  2073. self.config.save_pretrained(save_directory)
  2074. if self.can_generate():
  2075. self.generation_config.save_pretrained(save_directory)
  2076. # If we save using the predefined names, we can load using `from_pretrained`
  2077. weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME
  2078. output_model_file = os.path.join(save_directory, weights_name)
  2079. shards, index = tf_shard_checkpoint(self.weights, max_shard_size, weights_name=weights_name)
  2080. # Clean the folder from a previous save
  2081. for filename in os.listdir(save_directory):
  2082. full_filename = os.path.join(save_directory, filename)
  2083. # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
  2084. # in distributed settings to avoid race conditions.
  2085. weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
  2086. if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and filename not in shards:
  2087. os.remove(full_filename)
  2088. if index is None:
  2089. if safe_serialization:
  2090. state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in self.weights}
  2091. safe_save_file(state_dict, output_model_file, metadata={"format": "tf"})
  2092. else:
  2093. self.save_weights(output_model_file)
  2094. logger.info(f"Model weights saved in {output_model_file}")
  2095. else:
  2096. save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else TF2_WEIGHTS_INDEX_NAME
  2097. save_index_file = os.path.join(save_directory, save_index_file)
  2098. # Save the index as well
  2099. with open(save_index_file, "w", encoding="utf-8") as index_file:
  2100. content = json.dumps(index, indent=2, sort_keys=True) + "\n"
  2101. index_file.write(content)
  2102. logger.info(
  2103. f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
  2104. f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
  2105. f"index located at {save_index_file}."
  2106. )
  2107. for shard_file, shard in shards.items():
  2108. if safe_serialization:
  2109. shard_state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in shard}
  2110. safe_save_file(
  2111. shard_state_dict, os.path.join(save_directory, shard_file), metadata={"format": "tf"}
  2112. )
  2113. else:
  2114. with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
  2115. layers = []
  2116. for layer in sorted(shard, key=lambda x: x.name):
  2117. if "model." in layer.name or len(layer.name.split("/")) == 1:
  2118. layer_name = layer.name
  2119. else:
  2120. layer_name = "/".join(layer.name.split("/")[1:])
  2121. param_dset = shard_file.create_dataset(
  2122. layer_name, layer.numpy().shape, dtype=layer.numpy().dtype
  2123. )
  2124. param_dset[:] = layer.numpy()
  2125. layers.append(layer_name.encode("utf8"))
  2126. save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
  2127. if push_to_hub:
  2128. self._upload_modified_files(
  2129. save_directory,
  2130. repo_id,
  2131. files_timestamps,
  2132. commit_message=commit_message,
  2133. token=token,
  2134. )
  2135. @classmethod
  2136. def from_pretrained(
  2137. cls,
  2138. pretrained_model_name_or_path: str | os.PathLike | None,
  2139. *model_args,
  2140. config: PretrainedConfig | str | os.PathLike | None = None,
  2141. cache_dir: str | os.PathLike | None = None,
  2142. ignore_mismatched_sizes: bool = False,
  2143. force_download: bool = False,
  2144. local_files_only: bool = False,
  2145. token: str | bool | None = None,
  2146. revision: str = "main",
  2147. use_safetensors: bool | None = None,
  2148. **kwargs,
  2149. ):
  2150. r"""
  2151. Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.
  2152. The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
  2153. pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
  2154. task.
  2155. The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
  2156. weights are discarded.
  2157. Parameters:
  2158. pretrained_model_name_or_path (`str`, *optional*):
  2159. Can be either:
  2160. - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
  2161. - A path to a *directory* containing model weights saved using
  2162. [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
  2163. - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this
  2164. case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
  2165. argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
  2166. using the provided conversion scripts and loading the TensorFlow model afterwards.
  2167. - `None` if you are both providing the configuration and state dictionary (resp. with keyword
  2168. arguments `config` and `state_dict`).
  2169. model_args (sequence of positional arguments, *optional*):
  2170. All remaining positional arguments will be passed to the underlying model's `__init__` method.
  2171. config (`Union[PretrainedConfig, str]`, *optional*):
  2172. Can be either:
  2173. - an instance of a class derived from [`PretrainedConfig`],
  2174. - a string valid as input to [`~PretrainedConfig.from_pretrained`].
  2175. Configuration for the model to use instead of an automatically loaded configuration. Configuration can
  2176. be automatically loaded when:
  2177. - The model is a model provided by the library (loaded with the *model id* string of a pretrained
  2178. model).
  2179. - The model was saved using [`~TFPreTrainedModel.save_pretrained`] and is reloaded by supplying the
  2180. save directory.
  2181. - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
  2182. configuration JSON file named *config.json* is found in the directory.
  2183. from_pt (`bool`, *optional*, defaults to `False`):
  2184. Load the model weights from a PyTorch state_dict save file (see docstring of
  2185. `pretrained_model_name_or_path` argument).
  2186. ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
  2187. Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
  2188. as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
  2189. checkpoint with 3 labels).
  2190. cache_dir (`str`, *optional*):
  2191. Path to a directory in which a downloaded pretrained model configuration should be cached if the
  2192. standard cache should not be used.
  2193. force_download (`bool`, *optional*, defaults to `False`):
  2194. Whether or not to force the (re-)download of the model weights and configuration files, overriding the
  2195. cached versions if they exist.
  2196. resume_download:
  2197. Deprecated and ignored. All downloads are now resumed by default when possible.
  2198. Will be removed in v5 of Transformers.
  2199. proxies:
  2200. (`dict[str, str], `optional`): A dictionary of proxy servers to use by protocol or endpoint, e.g.,
  2201. `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
  2202. output_loading_info(`bool`, *optional*, defaults to `False`): Whether ot not to also return a
  2203. dictionary containing missing keys, unexpected keys and error messages.
  2204. local_files_only(`bool`, *optional*, defaults to `False`):
  2205. Whether or not to only look at local files (e.g., not try downloading the model).
  2206. token (`str` or `bool`, *optional*):
  2207. The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
  2208. the token generated when running `hf auth login` (stored in `~/.huggingface`).
  2209. revision (`str`, *optional*, defaults to `"main"`):
  2210. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  2211. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  2212. identifier allowed by git.
  2213. <Tip>
  2214. To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
  2215. </Tip>
  2216. mirror (`str`, *optional*):
  2217. Mirror source to accelerate downloads in China. If you are from China and have an accessibility
  2218. problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
  2219. Please refer to the mirror site for more information.
  2220. subfolder (`str`, *optional*, defaults to `""`):
  2221. In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
  2222. specify the folder name here.
  2223. tf_to_pt_weight_rename (`Callable`, *optional*):
  2224. A function that is called to transform the names of weights during the PyTorch to TensorFlow
  2225. crossloading process. This is not necessary for most models, but is useful to allow composite models to
  2226. be crossloaded correctly.
  2227. use_safetensors (`bool`, *optional*, defaults to `None`):
  2228. Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
  2229. is not installed, it will be set to `False`.
  2230. kwargs (remaining dictionary of keyword arguments, *optional*):
  2231. Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
  2232. `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
  2233. automatically loaded:
  2234. - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
  2235. underlying model's `__init__` method (we assume all relevant updates to the configuration have
  2236. already been done)
  2237. - If a configuration is not provided, `kwargs` will be first passed to the configuration class
  2238. initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
  2239. corresponds to a configuration attribute will be used to override said attribute with the
  2240. supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
  2241. will be passed to the underlying model's `__init__` function.
  2242. Examples:
  2243. ```python
  2244. >>> from transformers import BertConfig, TFBertModel
  2245. >>> # Download model and configuration from huggingface.co and cache.
  2246. >>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased")
  2247. >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
  2248. >>> model = TFBertModel.from_pretrained("./test/saved_model/")
  2249. >>> # Update configuration during loading.
  2250. >>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
  2251. >>> assert model.config.output_attentions == True
  2252. >>> # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
  2253. >>> config = BertConfig.from_json_file("./pt_model/my_pt_model_config.json")
  2254. >>> model = TFBertModel.from_pretrained("./pt_model/my_pytorch_model.bin", from_pt=True, config=config)
  2255. ```"""
  2256. from_pt = kwargs.pop("from_pt", False)
  2257. resume_download = kwargs.pop("resume_download", None)
  2258. proxies = kwargs.pop("proxies", None)
  2259. output_loading_info = kwargs.pop("output_loading_info", False)
  2260. use_auth_token = kwargs.pop("use_auth_token", None)
  2261. trust_remote_code = kwargs.pop("trust_remote_code", None)
  2262. _ = kwargs.pop("mirror", None)
  2263. load_weight_prefix = kwargs.pop("load_weight_prefix", None)
  2264. from_pipeline = kwargs.pop("_from_pipeline", None)
  2265. from_auto_class = kwargs.pop("_from_auto", False)
  2266. subfolder = kwargs.pop("subfolder", "")
  2267. commit_hash = kwargs.pop("_commit_hash", None)
  2268. tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None)
  2269. # Not relevant for TF models
  2270. _ = kwargs.pop("adapter_kwargs", None)
  2271. if use_auth_token is not None:
  2272. warnings.warn(
  2273. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  2274. FutureWarning,
  2275. )
  2276. if token is not None:
  2277. raise ValueError(
  2278. "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
  2279. )
  2280. token = use_auth_token
  2281. if trust_remote_code is True:
  2282. logger.warning(
  2283. "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
  2284. " ignored."
  2285. )
  2286. user_agent = {"file_type": "model", "framework": "tensorflow", "from_auto_class": from_auto_class}
  2287. if from_pipeline is not None:
  2288. user_agent["using_pipeline"] = from_pipeline
  2289. if is_offline_mode() and not local_files_only:
  2290. logger.info("Offline mode: forcing local_files_only=True")
  2291. local_files_only = True
  2292. if use_safetensors is None and not is_safetensors_available():
  2293. use_safetensors = False
  2294. # Load config if we don't provide a configuration
  2295. if not isinstance(config, PretrainedConfig):
  2296. config_path = config if config is not None else pretrained_model_name_or_path
  2297. config, model_kwargs = cls.config_class.from_pretrained(
  2298. config_path,
  2299. cache_dir=cache_dir,
  2300. return_unused_kwargs=True,
  2301. force_download=force_download,
  2302. resume_download=resume_download,
  2303. proxies=proxies,
  2304. local_files_only=local_files_only,
  2305. token=token,
  2306. revision=revision,
  2307. _from_auto=from_auto_class,
  2308. _from_pipeline=from_pipeline,
  2309. _commit_hash=commit_hash,
  2310. **kwargs,
  2311. )
  2312. else:
  2313. model_kwargs = kwargs
  2314. if commit_hash is None:
  2315. commit_hash = getattr(config, "_commit_hash", None)
  2316. # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
  2317. # index of the files.
  2318. is_sharded = False
  2319. # Load model
  2320. if pretrained_model_name_or_path is not None:
  2321. pretrained_model_name_or_path = str(pretrained_model_name_or_path)
  2322. is_local = os.path.isdir(pretrained_model_name_or_path)
  2323. if is_local:
  2324. if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
  2325. # Load from a PyTorch checkpoint in priority if from_pt
  2326. archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
  2327. elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
  2328. # Load from a sharded PyTorch checkpoint
  2329. archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
  2330. is_sharded = True
  2331. elif use_safetensors is not False and os.path.isfile(
  2332. os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
  2333. ):
  2334. # Load from a safetensors checkpoint
  2335. archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
  2336. elif use_safetensors is not False and os.path.isfile(
  2337. os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
  2338. ):
  2339. # Load from a sharded safetensors checkpoint
  2340. archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
  2341. is_sharded = True
  2342. elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
  2343. # Load from a TF 2.0 checkpoint
  2344. archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
  2345. elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)):
  2346. # Load from a sharded TF 2.0 checkpoint
  2347. archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
  2348. is_sharded = True
  2349. # At this stage we don't have a weight file so we will raise an error.
  2350. elif use_safetensors:
  2351. raise OSError(
  2352. f"Error no file named {SAFE_WEIGHTS_NAME} or {SAFE_WEIGHTS_INDEX_NAME} found in directory {pretrained_model_name_or_path}. "
  2353. f"Please make sure that the model has been saved with `safe_serialization=True` or do not "
  2354. f"set `use_safetensors=True`."
  2355. )
  2356. elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile(
  2357. os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
  2358. ):
  2359. raise OSError(
  2360. f"Error no file named {TF2_WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
  2361. "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
  2362. "weights."
  2363. )
  2364. else:
  2365. raise OSError(
  2366. f"Error no file named {TF2_WEIGHTS_NAME}, {SAFE_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
  2367. f"{pretrained_model_name_or_path}."
  2368. )
  2369. elif os.path.isfile(pretrained_model_name_or_path):
  2370. archive_file = pretrained_model_name_or_path
  2371. is_local = True
  2372. elif os.path.isfile(pretrained_model_name_or_path + ".index"):
  2373. archive_file = pretrained_model_name_or_path + ".index"
  2374. is_local = True
  2375. elif is_remote_url(pretrained_model_name_or_path):
  2376. filename = pretrained_model_name_or_path
  2377. resolved_archive_file = download_url(pretrained_model_name_or_path)
  2378. else:
  2379. # set correct filename
  2380. if from_pt:
  2381. filename = WEIGHTS_NAME
  2382. elif use_safetensors is not False:
  2383. filename = SAFE_WEIGHTS_NAME
  2384. else:
  2385. filename = TF2_WEIGHTS_NAME
  2386. try:
  2387. # Load from URL or cache if already cached
  2388. cached_file_kwargs = {
  2389. "cache_dir": cache_dir,
  2390. "force_download": force_download,
  2391. "proxies": proxies,
  2392. "resume_download": resume_download,
  2393. "local_files_only": local_files_only,
  2394. "token": token,
  2395. "user_agent": user_agent,
  2396. "revision": revision,
  2397. "subfolder": subfolder,
  2398. "_raise_exceptions_for_gated_repo": False,
  2399. "_raise_exceptions_for_missing_entries": False,
  2400. "_commit_hash": commit_hash,
  2401. }
  2402. resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
  2403. # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
  2404. # result when internet is up, the repo and revision exist, but the file does not.
  2405. if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
  2406. # Did not find the safetensors file, let's fallback to TF.
  2407. # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
  2408. filename = TF2_WEIGHTS_NAME
  2409. resolved_archive_file = cached_file(
  2410. pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs
  2411. )
  2412. if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
  2413. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
  2414. resolved_archive_file = cached_file(
  2415. pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME, **cached_file_kwargs
  2416. )
  2417. if resolved_archive_file is not None:
  2418. is_sharded = True
  2419. if resolved_archive_file is None and filename == WEIGHTS_NAME:
  2420. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
  2421. resolved_archive_file = cached_file(
  2422. pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
  2423. )
  2424. if resolved_archive_file is not None:
  2425. is_sharded = True
  2426. if resolved_archive_file is None:
  2427. # Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error
  2428. # message.
  2429. has_file_kwargs = {
  2430. "revision": revision,
  2431. "proxies": proxies,
  2432. "token": token,
  2433. "cache_dir": cache_dir,
  2434. "local_files_only": local_files_only,
  2435. }
  2436. if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
  2437. is_sharded = True
  2438. elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
  2439. raise OSError(
  2440. f"{pretrained_model_name_or_path} does not appear to have a file named"
  2441. f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
  2442. " load this model from those weights."
  2443. )
  2444. else:
  2445. raise OSError(
  2446. f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
  2447. f" {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
  2448. )
  2449. except OSError:
  2450. # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
  2451. # to the original exception.
  2452. raise
  2453. except Exception:
  2454. # For any other exception, we throw a generic error.
  2455. raise OSError(
  2456. f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
  2457. " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
  2458. f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
  2459. f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
  2460. )
  2461. if is_local:
  2462. logger.info(f"loading weights file {archive_file}")
  2463. resolved_archive_file = archive_file
  2464. filename = resolved_archive_file.split(os.path.sep)[-1]
  2465. else:
  2466. logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
  2467. else:
  2468. resolved_archive_file = None
  2469. # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
  2470. if is_sharded:
  2471. # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
  2472. resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
  2473. pretrained_model_name_or_path,
  2474. resolved_archive_file,
  2475. cache_dir=cache_dir,
  2476. force_download=force_download,
  2477. proxies=proxies,
  2478. resume_download=resume_download,
  2479. local_files_only=local_files_only,
  2480. token=token,
  2481. user_agent=user_agent,
  2482. revision=revision,
  2483. _commit_hash=commit_hash,
  2484. )
  2485. safetensors_from_pt = False
  2486. if filename == SAFE_WEIGHTS_NAME:
  2487. with safe_open(resolved_archive_file, framework="tf") as f:
  2488. safetensors_metadata = f.metadata()
  2489. if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
  2490. raise OSError(
  2491. f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
  2492. " Make sure you save your model with the `save_pretrained` method."
  2493. )
  2494. safetensors_from_pt = safetensors_metadata.get("format") == "pt"
  2495. elif filename == SAFE_WEIGHTS_INDEX_NAME:
  2496. with safe_open(resolved_archive_file[0], framework="tf") as f:
  2497. safetensors_metadata = f.metadata()
  2498. if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
  2499. raise OSError(
  2500. f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
  2501. " Make sure you save your model with the `save_pretrained` method."
  2502. )
  2503. safetensors_from_pt = safetensors_metadata.get("format") == "pt"
  2504. config.name_or_path = pretrained_model_name_or_path
  2505. # composed models, *e.g.* TFRag, require special treatment when it comes to loading
  2506. # pre-trained weights.
  2507. if cls._requires_load_weight_prefix and model_kwargs.get("name") is not None:
  2508. model_kwargs["load_weight_prefix"] = load_weight_prefix + "/" + model_kwargs.get("name")
  2509. # Instantiate model.
  2510. model = cls(config, *model_args, **model_kwargs)
  2511. if tf_to_pt_weight_rename is None and hasattr(model, "tf_to_pt_weight_rename"):
  2512. # TODO Matt: This is a temporary workaround to allow weight renaming, but requires a method
  2513. # to be defined for each class that requires a rename. We can probably just have a class-level
  2514. # dict and a single top-level method or something and cut down a lot of boilerplate code
  2515. tf_to_pt_weight_rename = model.tf_to_pt_weight_rename
  2516. if from_pt:
  2517. from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
  2518. # Load from a PyTorch checkpoint
  2519. return load_pytorch_checkpoint_in_tf2_model(
  2520. model,
  2521. resolved_archive_file,
  2522. allow_missing_keys=True,
  2523. output_loading_info=output_loading_info,
  2524. _prefix=load_weight_prefix,
  2525. tf_to_pt_weight_rename=tf_to_pt_weight_rename,
  2526. )
  2527. # we might need to extend the variable scope for composite models
  2528. if load_weight_prefix is not None:
  2529. with tf.compat.v1.variable_scope(load_weight_prefix):
  2530. model.build_in_name_scope() # build the network with dummy inputs
  2531. else:
  2532. model.build_in_name_scope() # build the network with dummy inputs
  2533. if safetensors_from_pt and not is_sharded:
  2534. from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
  2535. with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
  2536. # Load from a PyTorch safetensors checkpoint
  2537. # We load in TF format here because PT weights often need to be transposed, and this is much
  2538. # faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times.
  2539. return load_pytorch_state_dict_in_tf2_model(
  2540. model,
  2541. safetensors_archive,
  2542. tf_inputs=False, # No need to build the model again
  2543. allow_missing_keys=True,
  2544. output_loading_info=output_loading_info,
  2545. _prefix=load_weight_prefix,
  2546. ignore_mismatched_sizes=ignore_mismatched_sizes,
  2547. tf_to_pt_weight_rename=tf_to_pt_weight_rename,
  2548. )
  2549. elif safetensors_from_pt:
  2550. from .modeling_tf_pytorch_utils import load_sharded_pytorch_safetensors_in_tf2_model
  2551. return load_sharded_pytorch_safetensors_in_tf2_model(
  2552. model,
  2553. resolved_archive_file,
  2554. tf_inputs=False,
  2555. allow_missing_keys=True,
  2556. output_loading_info=output_loading_info,
  2557. _prefix=load_weight_prefix,
  2558. ignore_mismatched_sizes=ignore_mismatched_sizes,
  2559. tf_to_pt_weight_rename=tf_to_pt_weight_rename,
  2560. )
  2561. # 'by_name' allow us to do transfer learning by skipping/adding layers
  2562. # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
  2563. try:
  2564. if is_sharded:
  2565. for file in resolved_archive_file:
  2566. os.path.isfile(file), f"Error retrieving files {file}"
  2567. if filename == SAFE_WEIGHTS_INDEX_NAME:
  2568. missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights_from_safetensors(
  2569. model,
  2570. resolved_archive_file,
  2571. ignore_mismatched_sizes=ignore_mismatched_sizes,
  2572. _prefix=load_weight_prefix,
  2573. )
  2574. else:
  2575. missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights(
  2576. model,
  2577. resolved_archive_file,
  2578. ignore_mismatched_sizes=ignore_mismatched_sizes,
  2579. _prefix=load_weight_prefix,
  2580. )
  2581. else:
  2582. # Handles both H5 and safetensors
  2583. missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
  2584. model,
  2585. resolved_archive_file,
  2586. ignore_mismatched_sizes=ignore_mismatched_sizes,
  2587. _prefix=load_weight_prefix,
  2588. )
  2589. except OSError as e:
  2590. try:
  2591. with open(resolved_archive_file) as f:
  2592. if f.read().startswith("version"):
  2593. raise OSError(
  2594. "You seem to have cloned a repository without having git-lfs installed. Please install "
  2595. "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
  2596. "you cloned."
  2597. )
  2598. else:
  2599. raise ValueError from e
  2600. except (UnicodeDecodeError, ValueError):
  2601. raise OSError(
  2602. "Unable to load weights from h5 file. "
  2603. "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
  2604. )
  2605. if cls._keys_to_ignore_on_load_missing is not None:
  2606. for pat in cls._keys_to_ignore_on_load_missing:
  2607. missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
  2608. if cls._keys_to_ignore_on_load_unexpected is not None:
  2609. for pat in cls._keys_to_ignore_on_load_unexpected:
  2610. unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
  2611. if len(unexpected_keys) > 0:
  2612. logger.warning(
  2613. f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when"
  2614. f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
  2615. f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
  2616. " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
  2617. " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
  2618. f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
  2619. " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
  2620. )
  2621. else:
  2622. logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n")
  2623. if len(missing_keys) > 0:
  2624. logger.warning(
  2625. f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at"
  2626. f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
  2627. " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
  2628. )
  2629. elif len(mismatched_keys) == 0:
  2630. logger.warning(
  2631. f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at"
  2632. f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
  2633. f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
  2634. " training."
  2635. )
  2636. if len(mismatched_keys) > 0:
  2637. mismatched_warning = "\n".join(
  2638. [
  2639. f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
  2640. for key, shape1, shape2 in mismatched_keys
  2641. ]
  2642. )
  2643. logger.warning(
  2644. f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
  2645. f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
  2646. f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
  2647. " to use it for predictions and inference."
  2648. )
  2649. # If it is a model with generation capabilities, attempt to load the generation config
  2650. if model.can_generate():
  2651. try:
  2652. model.generation_config = GenerationConfig.from_pretrained(
  2653. pretrained_model_name_or_path,
  2654. cache_dir=cache_dir,
  2655. force_download=force_download,
  2656. resume_download=resume_download,
  2657. proxies=proxies,
  2658. local_files_only=local_files_only,
  2659. token=token,
  2660. revision=revision,
  2661. subfolder=subfolder,
  2662. _from_auto=from_auto_class,
  2663. _from_pipeline=from_pipeline,
  2664. **kwargs,
  2665. )
  2666. except OSError:
  2667. logger.info(
  2668. "Generation config file not found, using a generation config created from the model config."
  2669. )
  2670. pass
  2671. if output_loading_info:
  2672. loading_info = {
  2673. "missing_keys": missing_keys,
  2674. "unexpected_keys": unexpected_keys,
  2675. "mismatched_keys": mismatched_keys,
  2676. }
  2677. return model, loading_info
  2678. return model
  2679. def push_to_hub(
  2680. self,
  2681. repo_id: str,
  2682. use_temp_dir: bool | None = None,
  2683. commit_message: str | None = None,
  2684. private: bool | None = None,
  2685. max_shard_size: int | str | None = "10GB",
  2686. token: bool | str | None = None,
  2687. # (`use_auth_token` is deprecated: we have to keep it here as we don't have **kwargs)
  2688. use_auth_token: bool | str | None = None,
  2689. create_pr: bool = False,
  2690. **base_model_card_args,
  2691. ) -> str:
  2692. """
  2693. Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.
  2694. Parameters:
  2695. repo_id (`str`):
  2696. The name of the repository you want to push your model to. It should contain your organization name
  2697. when pushing to a given organization.
  2698. use_temp_dir (`bool`, *optional*):
  2699. Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
  2700. Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
  2701. commit_message (`str`, *optional*):
  2702. Message to commit while pushing. Will default to `"Upload model"`.
  2703. private (`bool`, *optional*):
  2704. Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
  2705. token (`bool` or `str`, *optional*):
  2706. The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
  2707. when running `hf auth login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`
  2708. is not specified.
  2709. max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
  2710. Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
  2711. will then be each of size lower than this size. If expressed as a string, needs to be digits followed
  2712. by a unit (like `"5MB"`).
  2713. create_pr (`bool`, *optional*, defaults to `False`):
  2714. Whether or not to create a PR with the uploaded files or directly commit.
  2715. Examples:
  2716. ```python
  2717. from transformers import TFAutoModel
  2718. model = TFAutoModel.from_pretrained("google-bert/bert-base-cased")
  2719. # Push the model to your namespace with the name "my-finetuned-bert".
  2720. model.push_to_hub("my-finetuned-bert")
  2721. # Push the model to an organization with the name "my-finetuned-bert".
  2722. model.push_to_hub("huggingface/my-finetuned-bert")
  2723. ```
  2724. """
  2725. if use_auth_token is not None:
  2726. warnings.warn(
  2727. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  2728. FutureWarning,
  2729. )
  2730. if token is not None:
  2731. raise ValueError(
  2732. "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
  2733. )
  2734. token = use_auth_token
  2735. if "repo_path_or_name" in base_model_card_args:
  2736. warnings.warn(
  2737. "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
  2738. "`repo_id` instead."
  2739. )
  2740. repo_id = base_model_card_args.pop("repo_path_or_name")
  2741. # Deprecation warning will be sent after for repo_url and organization
  2742. repo_url = base_model_card_args.pop("repo_url", None)
  2743. organization = base_model_card_args.pop("organization", None)
  2744. if os.path.isdir(repo_id):
  2745. working_dir = repo_id
  2746. repo_id = repo_id.split(os.path.sep)[-1]
  2747. else:
  2748. working_dir = repo_id.split("/")[-1]
  2749. repo_id = self._create_repo(
  2750. repo_id, private=private, token=token, repo_url=repo_url, organization=organization
  2751. )
  2752. if use_temp_dir is None:
  2753. use_temp_dir = not os.path.isdir(working_dir)
  2754. with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
  2755. files_timestamps = self._get_files_timestamps(work_dir)
  2756. # Save all files.
  2757. self.save_pretrained(work_dir, max_shard_size=max_shard_size)
  2758. if hasattr(self, "history") and hasattr(self, "create_model_card"):
  2759. # This is a Keras model and we might be able to fish out its History and make a model card out of it
  2760. base_model_card_args = {
  2761. "output_dir": work_dir,
  2762. "model_name": Path(repo_id).name,
  2763. }
  2764. base_model_card_args.update(base_model_card_args)
  2765. self.create_model_card(**base_model_card_args)
  2766. self._upload_modified_files(
  2767. work_dir,
  2768. repo_id,
  2769. files_timestamps,
  2770. commit_message=commit_message,
  2771. token=token,
  2772. create_pr=create_pr,
  2773. )
  2774. @classmethod
  2775. def register_for_auto_class(cls, auto_class="TFAutoModel"):
  2776. """
  2777. Register this class with a given auto class. This should only be used for custom models as the ones in the
  2778. library are already mapped with an auto class.
  2779. Args:
  2780. auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`):
  2781. The auto class to register this new model with.
  2782. """
  2783. if not isinstance(auto_class, str):
  2784. auto_class = auto_class.__name__
  2785. import transformers.models.auto as auto_module
  2786. if not hasattr(auto_module, auto_class):
  2787. raise ValueError(f"{auto_class} is not a valid auto class.")
  2788. cls._auto_class = auto_class
  2789. class TFConv1D(keras.layers.Layer):
  2790. """
  2791. 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
  2792. Basically works like a linear layer but the weights are transposed.
  2793. Args:
  2794. nf (`int`):
  2795. The number of output features.
  2796. nx (`int`):
  2797. The number of input features.
  2798. initializer_range (`float`, *optional*, defaults to 0.02):
  2799. The standard deviation to use to initialize the weights.
  2800. kwargs (`dict[str, Any]`, *optional*):
  2801. Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
  2802. """
  2803. def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
  2804. super().__init__(**kwargs)
  2805. self.nf = nf
  2806. self.nx = nx
  2807. self.initializer_range = initializer_range
  2808. def build(self, input_shape):
  2809. if self.built:
  2810. return
  2811. self.built = True
  2812. self.weight = self.add_weight(
  2813. "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
  2814. )
  2815. self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())
  2816. def call(self, x):
  2817. bz, sl = shape_list(x)[:2]
  2818. x = tf.reshape(x, [-1, self.nx])
  2819. x = tf.matmul(x, self.weight) + self.bias
  2820. x = tf.reshape(x, [bz, sl, self.nf])
  2821. return x
  2822. class TFSharedEmbeddings(keras.layers.Layer):
  2823. r"""
  2824. Construct shared token embeddings.
  2825. The weights of the embedding layer is usually shared with the weights of the linear decoder when doing language
  2826. modeling.
  2827. Args:
  2828. vocab_size (`int`):
  2829. The size of the vocabulary, e.g., the number of unique tokens.
  2830. hidden_size (`int`):
  2831. The size of the embedding vectors.
  2832. initializer_range (`float`, *optional*):
  2833. The standard deviation to use when initializing the weights. If no value is provided, it will default to
  2834. \\(1/\sqrt{hidden\_size}\\).
  2835. kwargs (`dict[str, Any]`, *optional*):
  2836. Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
  2837. """
  2838. # TODO (joao): flagged for detection due to embeddings refactor
  2839. def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float | None = None, **kwargs):
  2840. super().__init__(**kwargs)
  2841. self.vocab_size = vocab_size
  2842. self.hidden_size = hidden_size
  2843. self.initializer_range = hidden_size**-0.5 if initializer_range is None else initializer_range
  2844. warnings.warn(
  2845. "`TFSharedEmbeddings` is scheduled for deletion in v4.32, use `keras.layers.Embedding` instead.",
  2846. DeprecationWarning,
  2847. )
  2848. def build(self, input_shape):
  2849. """
  2850. Build shared token embedding layer Shared weights logic adapted from
  2851. https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
  2852. """
  2853. self.weight = self.add_weight(
  2854. "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
  2855. )
  2856. super().build(input_shape)
  2857. def get_config(self):
  2858. config = {
  2859. "vocab_size": self.vocab_size,
  2860. "hidden_size": self.hidden_size,
  2861. "initializer_range": self.initializer_range,
  2862. }
  2863. base_config = super().get_config()
  2864. return dict(list(base_config.items()) + list(config.items()))
  2865. def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor:
  2866. """
  2867. Get token embeddings of inputs or decode final hidden state.
  2868. Args:
  2869. inputs (`tf.Tensor`):
  2870. In embedding mode, should be an int64 tensor with shape `[batch_size, length]`.
  2871. In linear mode, should be a float tensor with shape `[batch_size, length, hidden_size]`.
  2872. mode (`str`, defaults to `"embedding"`):
  2873. A valid value is either `"embedding"` or `"linear"`, the first one indicates that the layer should be
  2874. used as an embedding layer, the second one that the layer should be used as a linear decoder.
  2875. Returns:
  2876. `tf.Tensor`: In embedding mode, the output is a float32 embedding tensor, with shape `[batch_size, length,
  2877. embedding_size]`.
  2878. In linear mode, the output is a float32 with shape `[batch_size, length, vocab_size]`.
  2879. Raises:
  2880. ValueError: if `mode` is not valid.
  2881. Shared weights logic is adapted from
  2882. [here](https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24).
  2883. """
  2884. if mode == "embedding":
  2885. return self._embedding(inputs)
  2886. elif mode == "linear":
  2887. return self._linear(inputs)
  2888. else:
  2889. raise ValueError(f"mode {mode} is not valid.")
  2890. def _embedding(self, input_ids):
  2891. """Applies embedding based on inputs tensor."""
  2892. return tf.gather(self.weight, input_ids)
  2893. def _linear(self, inputs):
  2894. """
  2895. Computes logits by running inputs through a linear layer.
  2896. Args:
  2897. inputs: A float32 tensor with shape [..., hidden_size]
  2898. Returns:
  2899. float32 tensor with shape [..., vocab_size].
  2900. """
  2901. first_dims = shape_list(inputs)[:-1]
  2902. x = tf.reshape(inputs, [-1, self.hidden_size])
  2903. logits = tf.matmul(x, self.weight, transpose_b=True)
  2904. return tf.reshape(logits, first_dims + [self.vocab_size])
  2905. class TFSequenceSummary(keras.layers.Layer):
  2906. """
  2907. Compute a single vector summary of a sequence hidden states.
  2908. Args:
  2909. config ([`PretrainedConfig`]):
  2910. The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
  2911. config class of your model for the default values it uses):
  2912. - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
  2913. - `"last"` -- Take the last token hidden state (like XLNet)
  2914. - `"first"` -- Take the first token hidden state (like Bert)
  2915. - `"mean"` -- Take the mean of all tokens hidden states
  2916. - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
  2917. - `"attn"` -- Not implemented now, use multi-head attention
  2918. - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
  2919. - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
  2920. (otherwise to `config.hidden_size`).
  2921. - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
  2922. another string or `None` will add no activation.
  2923. - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
  2924. - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
  2925. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation to use to initialize the weights.
  2926. kwargs (`dict[str, Any]`, *optional*):
  2927. Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
  2928. """
  2929. def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs):
  2930. super().__init__(**kwargs)
  2931. self.summary_type = config.summary_type if hasattr(config, "summary_use_proj") else "last"
  2932. if self.summary_type == "attn":
  2933. # We should use a standard multi-head attention module with absolute positional embedding for that.
  2934. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
  2935. # We can probably just use the multi-head attention module of PyTorch >=1.1.0
  2936. raise NotImplementedError
  2937. self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
  2938. if self.has_summary:
  2939. if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
  2940. num_classes = config.num_labels
  2941. else:
  2942. num_classes = config.hidden_size
  2943. self.summary = keras.layers.Dense(
  2944. num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
  2945. )
  2946. self.has_activation = False
  2947. activation_string = getattr(config, "summary_activation", None)
  2948. if activation_string is not None:
  2949. self.has_activation = True
  2950. self.activation = get_tf_activation(activation_string)
  2951. self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
  2952. if self.has_first_dropout:
  2953. self.first_dropout = keras.layers.Dropout(config.summary_first_dropout)
  2954. self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
  2955. if self.has_last_dropout:
  2956. self.last_dropout = keras.layers.Dropout(config.summary_last_dropout)
  2957. self.hidden_size = config.hidden_size
  2958. def call(self, inputs, cls_index=None, training=False):
  2959. if not isinstance(inputs, (dict, tuple, list)):
  2960. hidden_states = inputs
  2961. elif isinstance(inputs, (tuple, list)):
  2962. hidden_states = inputs[0]
  2963. cls_index = inputs[1] if len(inputs) > 1 else None
  2964. assert len(inputs) <= 2, "Too many inputs."
  2965. else:
  2966. hidden_states = inputs.get("hidden_states")
  2967. cls_index = inputs.get("cls_index", None)
  2968. if self.summary_type == "last":
  2969. output = hidden_states[:, -1]
  2970. elif self.summary_type == "first":
  2971. output = hidden_states[:, 0]
  2972. elif self.summary_type == "mean":
  2973. output = tf.reduce_mean(hidden_states, axis=1)
  2974. elif self.summary_type == "cls_index":
  2975. hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims]
  2976. if cls_index is None:
  2977. cls_index = tf.fill(
  2978. hidden_shape[:-2], hidden_shape[-2] - 1
  2979. ) # A tensor full of shape [batch] or [batch, num choices] full of sequence length
  2980. cls_shape = shape_list(cls_index)
  2981. if len(cls_shape) <= len(hidden_shape) - 2:
  2982. cls_index = tf.expand_dims(cls_index, axis=-1)
  2983. # else:
  2984. # cls_index = cls_index[..., tf.newaxis]
  2985. # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
  2986. # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
  2987. output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
  2988. output = tf.squeeze(
  2989. output, axis=len(hidden_shape) - 2
  2990. ) # shape of output: (batch, num choices, hidden_size)
  2991. elif self.summary_type == "attn":
  2992. raise NotImplementedError
  2993. if self.has_first_dropout:
  2994. output = self.first_dropout(output, training=training)
  2995. if self.has_summary:
  2996. output = self.summary(output)
  2997. if self.has_activation:
  2998. output = self.activation(output)
  2999. if self.has_last_dropout:
  3000. output = self.last_dropout(output, training=training)
  3001. return output
  3002. def build(self, input_shape):
  3003. if self.built:
  3004. return
  3005. self.built = True
  3006. if getattr(self, "summary", None) is not None:
  3007. with tf.name_scope("summary"):
  3008. self.summary.build(self.hidden_size)
  3009. def get_initializer(initializer_range: float = 0.02) -> keras.initializers.TruncatedNormal:
  3010. """
  3011. Creates a `keras.initializers.TruncatedNormal` with the given range.
  3012. Args:
  3013. initializer_range (*float*, defaults to 0.02): Standard deviation of the initializer range.
  3014. Returns:
  3015. `keras.initializers.TruncatedNormal`: The truncated normal initializer.
  3016. """
  3017. return keras.initializers.TruncatedNormal(stddev=initializer_range)