| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863 |
- # coding=utf-8
- # Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import copy
- import inspect
- import os
- import warnings
- from dataclasses import dataclass
- from typing import TYPE_CHECKING, Any, Callable, Optional, Union
- import torch
- import torch.distributed as dist
- from packaging import version
- from torch import nn
- from ..cache_utils import (
- Cache,
- DynamicCache,
- EncoderDecoderCache,
- QuantizedCache,
- StaticCache,
- )
- from ..dynamic_module_utils import (
- check_python_requirements,
- get_cached_module_file,
- get_class_in_module,
- resolve_trust_remote_code,
- )
- from ..integrations.deepspeed import is_deepspeed_zero3_enabled
- from ..integrations.fsdp import is_fsdp_managed_module
- from ..masking_utils import create_masks_for_generate
- from ..pytorch_utils import isin_mps_friendly
- from ..tokenization_utils import ExtensionsTrie
- from ..utils import (
- ModelOutput,
- TransformersKwargs,
- is_accelerate_available,
- is_hqq_available,
- is_optimum_quanto_available,
- is_torchdynamo_exporting,
- logging,
- )
- from .candidate_generator import (
- AssistantVocabTranslatorCache,
- AssistedCandidateGenerator,
- AssistedCandidateGeneratorDifferentTokenizers,
- CandidateGenerator,
- EarlyExitCandidateGenerator,
- PromptLookupCandidateGenerator,
- UniversalSpeculativeDecodingGenerator,
- _prepare_attention_mask,
- _prepare_token_type_ids,
- )
- from .configuration_utils import (
- ALL_STATIC_CACHE_IMPLEMENTATIONS,
- DEPRECATED_STATIC_CACHE_IMPLEMENTATIONS,
- STATIC_CACHE_IMPLEMENTATIONS,
- GenerationConfig,
- GenerationMode,
- )
- from .continuous_batching import ContinuousMixin
- from .logits_process import (
- EncoderNoRepeatNGramLogitsProcessor,
- EncoderRepetitionPenaltyLogitsProcessor,
- EpsilonLogitsWarper,
- EtaLogitsWarper,
- ExponentialDecayLengthPenalty,
- ForcedBOSTokenLogitsProcessor,
- ForcedEOSTokenLogitsProcessor,
- InfNanRemoveLogitsProcessor,
- LogitNormalization,
- LogitsProcessorList,
- MinLengthLogitsProcessor,
- MinNewTokensLengthLogitsProcessor,
- MinPLogitsWarper,
- NoBadWordsLogitsProcessor,
- NoRepeatNGramLogitsProcessor,
- PrefixConstrainedLogitsProcessor,
- RepetitionPenaltyLogitsProcessor,
- SequenceBiasLogitsProcessor,
- SuppressTokensAtBeginLogitsProcessor,
- SuppressTokensLogitsProcessor,
- TemperatureLogitsWarper,
- TopKLogitsWarper,
- TopPLogitsWarper,
- TypicalLogitsWarper,
- UnbatchedClassifierFreeGuidanceLogitsProcessor,
- )
- from .stopping_criteria import (
- ConfidenceCriteria,
- EosTokenCriteria,
- MaxLengthCriteria,
- MaxTimeCriteria,
- StoppingCriteria,
- StoppingCriteriaList,
- StopStringCriteria,
- )
- if TYPE_CHECKING:
- from ..modeling_utils import PreTrainedModel
- from ..tokenization_utils_base import PreTrainedTokenizerBase
- from .streamers import BaseStreamer
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# The accelerate hook helpers are only imported when `is_accelerate_available()` reports that the
# optional dependency can be used; code that needs them must guard on the same check.
if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, add_hook_to_module
# Variable names used to hold the cache at generation time
ALL_CACHE_NAMES = [
    "past_key_values",  # default
    "cache_params",  # mamba-based models
    "state",  # rwkv
    "mems",  # xlnet
    "past_buckets_states",  # reformer
]
# Maps each `GenerationMode` to a string identifying how that mode is carried out.
# NOTE(review): the first group of values look like names of private decoding methods, and the
# "Deprecated methods" group look like `<organization>/<name>` repository ids -- confirm against the
# code that consumes this mapping.
GENERATION_MODES_MAPPING = {
    # Sampling and greedy search share one entry, as do the two beam-based modes.
    GenerationMode.SAMPLE: "_sample",
    GenerationMode.GREEDY_SEARCH: "_sample",
    GenerationMode.BEAM_SEARCH: "_beam_search",
    GenerationMode.BEAM_SAMPLE: "_beam_search",
    GenerationMode.ASSISTED_GENERATION: "_assisted_decoding",
    # Deprecated methods
    GenerationMode.DOLA_GENERATION: "transformers-community/dola",
    GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search",
    GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search",
    GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search",
}
@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
    """
    Outputs of decoder-only generation models, when using non-beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
        past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    # `sequences` is the only required field; every other field defaults to `None`.
    sequences: torch.LongTensor
    scores: Optional[tuple[torch.FloatTensor]] = None
    logits: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[Cache] = None
@dataclass
class GenerateEncoderDecoderOutput(ModelOutput):
    """
    Outputs of encoder-decoder generation models, when using non-beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
        past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    # `sequences` is the only required field; every other field defaults to `None`.
    sequences: torch.LongTensor
    scores: Optional[tuple[torch.FloatTensor]] = None
    logits: Optional[tuple[torch.FloatTensor]] = None
    # NOTE(review): the docstring describes `encoder_attentions` as "one for each layer of the decoder";
    # this is presumably meant to read "encoder" -- confirm before rewording.
    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[Cache] = None
@dataclass
class GenerateBeamDecoderOnlyOutput(ModelOutput):
    """
    Outputs of decoder-only generation models, when using beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, sequence_length)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
        past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    # `sequences` is the only required field; every other field defaults to `None`.
    sequences: torch.LongTensor
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[tuple[torch.FloatTensor]] = None
    logits: Optional[tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[Cache] = None
@dataclass
class GenerateBeamEncoderDecoderOutput(ModelOutput):
    """
    Outputs of encoder-decoder generation models, when using beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, sequence_length)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
        past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format, check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
    """

    # `sequences` is the only required field; every other field defaults to `None`.
    sequences: torch.LongTensor
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[tuple[torch.FloatTensor]] = None
    logits: Optional[tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    # NOTE(review): the docstring describes `encoder_attentions` as "one for each layer of the decoder";
    # this is presumably meant to read "encoder" -- confirm before rewording.
    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[Cache] = None
# TODO (joao): remove the equivalent classes and typing shortcuts below in v5
# Equivalent classes (kept for retrocompatibility purposes)
# Each legacy name is a plain alias: it is bound to the very same class object as the `Generate*` output
# class on the right-hand side, so `isinstance` checks against either name behave identically.
GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
SampleDecoderOnlyOutput = GenerateDecoderOnlyOutput

ContrastiveSearchEncoderDecoderOutput = GenerateEncoderDecoderOutput
GreedySearchEncoderDecoderOutput = GenerateEncoderDecoderOutput
SampleEncoderDecoderOutput = GenerateEncoderDecoderOutput

BeamSearchDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput
BeamSampleDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput

BeamSearchEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput
BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput

# Legacy per-strategy unions, each pairing the encoder-decoder and decoder-only variants.
GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]

# Typing shortcuts
GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]
GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput]
GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
- class GenerationMixin(ContinuousMixin):
- """
- A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes.
- Inheriting from this class causes the model to have special generation-related behavior, such as loading a
- `GenerationConfig` at initialization time or ensuring `generate`-related tests are run in `transformers` CI.
- A model class should inherit from `GenerationMixin` to enable calling methods like `generate`, or when it
- has defined a custom `generate` method that relies on `GenerationMixin`, directly or indirectly, which
- approximately shares the same interface to public methods like `generate`. Three examples:
- - `LlamaForCausalLM` should inherit from `GenerationMixin` to enable calling `generate` and other public
- methods in the mixin;
- - `BlipForQuestionAnswering` has a custom `generate` method that approximately shares the same interface as
- `GenerationMixin.generate` (it has a few extra arguments, and the same output). That function also calls
- `GenerationMixin.generate` indirectly, through an inner model. As such, `BlipForQuestionAnswering` should
- inherit from `GenerationMixin` to benefit from all generation-related automation in our codebase;
- - `BarkModel` has a custom `generate` method and one of its inner models calls `GenerationMixin.generate`.
- However, its `generate` does not share the same interface as `GenerationMixin.generate`. In this case,
- `BarkModel` should NOT inherit from `GenerationMixin`, as it breaks the `generate` interface.
- The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
- - *greedy decoding* if `num_beams=1` and `do_sample=False`
- - *multinomial sampling* if `num_beams=1` and `do_sample=True`
- - *beam-search decoding* if `num_beams>1` and `do_sample=False`
- - *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
- - *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
- To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
- """
def load_custom_generate(
    self,
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
    trust_remote_code: Optional[bool] = None,
    **kwargs,
) -> Callable:
    """
    Loads and returns a custom generate function, given a model repo.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
             Can be either:
                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing model weights saved using
                  [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
        trust_remote_code (`bool`, *optional*):
            Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
            should only be set to `True` for repositories you trust and in which you have read the code, as it will
            execute code present on the Hub on your local machine.
        **kwargs:
            Additional keyword arguments for remote code loading. They are forwarded to both the module lookup
            and the requirements check.

    Raises:
        OSError: If `pretrained_model_name_or_path` does not contain a `custom_generate` subdirectory. The
            original lookup error is attached as the explicit cause (`__cause__`).

    Returns:
        A callable that can be used to generate text.
    """
    # Fetches the generate.py file from the model repo. If it doesn't exist, a file in `.no_exist` cache directory
    # is created (preventing future hub requests), and an OSError is raised.
    try:
        module = get_cached_module_file(
            pretrained_model_name_or_path, module_file="custom_generate/generate.py", **kwargs
        )
    except OSError as err:
        # Chain explicitly so the traceback shows the lookup failure as the direct cause of this error.
        raise OSError(
            f"`{pretrained_model_name_or_path}` does not contain a `custom_generate` subdirectory with a "
            "`generate.py` file, can't load the custom generate function."
        ) from err

    # Handle opt-in `trust_remote_code` and related exceptions. A path that exists on disk is treated as local
    # code, anything else as remote code.
    is_local_code = os.path.exists(pretrained_model_name_or_path)
    error_message = (
        f"The repository `{pretrained_model_name_or_path}` contains custom generation code that will override "
        "the default `generate` method."
    )
    resolve_trust_remote_code(
        trust_remote_code,
        pretrained_model_name_or_path,
        has_local_code=is_local_code,
        has_remote_code=not is_local_code,
        error_message=error_message,
    )

    # Load the custom generate function (after checking the repo's declared python requirements)
    check_python_requirements(
        pretrained_model_name_or_path, requirements_file="custom_generate/requirements.txt", **kwargs
    )
    custom_generate_function = get_class_in_module("generate", module)
    return custom_generate_function
def _cache_dependant_input_preparation(
    self,
    input_ids: torch.LongTensor,
    inputs_embeds: Optional[torch.FloatTensor],
    cache_position: Optional[torch.LongTensor],
) -> tuple[torch.FloatTensor, torch.LongTensor]:
    """
    Generic cache-dependent input preparation.

    Kept as its own method so it can be unit tested in isolation, and because the exportable variant
    (`_cache_dependant_input_preparation_exporting`) needs a different implementation.

    When a cache is in use, `input_ids` is sliced through `cache_position` so that only the unprocessed tokens
    remain. The cases below are checked in order and the first match wins:

    - `inputs_embeds` is given and `input_ids` has zero columns: `inputs_embeds` is cut down to its last
      `cache_position.shape[0]` positions (the first token of each sequence is generated from the embeddings,
      later steps continue from the generated ids).
    - `inputs_embeds` is given (so `input_ids` may be missing entries), or `cache_position[-1]` is past the end
      of `input_ids` (can happen with synced GPUs, where only a dummy token is wanted): `input_ids` is cut down
      to its last `cache_position.shape[0]` columns.
    - `input_ids` and `cache_position` have different lengths: `input_ids` is indexed with `cache_position`.
    - Otherwise nothing is sliced (some generation methods slice `input_ids` themselves).

    Returns:
        The `(inputs_embeds, input_ids)` pair after slicing.

    The implementation does not rely on ``self`` apart from dispatching to the exporting variant; it is left as
    a standard method so it can be easily overridden.
    """
    if is_torchdynamo_exporting():
        return self._cache_dependant_input_preparation_exporting(input_ids, inputs_embeds, cache_position)

    embeds_given = inputs_embeds is not None

    # Embeddings only, no ids yet: trim the embeddings instead of the ids
    if embeds_given and input_ids.shape[1] == 0:
        return inputs_embeds[:, -cache_position.shape[0] :], input_ids

    # Ids may be incomplete (embeddings passed) or the cache position overshoots the ids: keep the tail
    if embeds_given or cache_position[-1] >= input_ids.shape[1]:
        return inputs_embeds, input_ids[:, -cache_position.shape[0] :]

    # Default case: select exactly the positions the cache has not seen
    if input_ids.shape[1] != cache_position.shape[0]:
        return inputs_embeds, input_ids[:, cache_position]

    # Lengths already agree: no-op
    return inputs_embeds, input_ids
- def _cache_dependant_input_preparation_exporting(
- self,
- input_ids: torch.LongTensor,
- inputs_embeds: Optional[torch.FloatTensor],
- cache_position: Optional[torch.LongTensor],
- ) -> tuple[torch.FloatTensor, torch.LongTensor]:
- """
- This method implements method ``_cache_dependant_input_preparation``
- with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
- The code is put in a separate function to allow granular unit testing.
- """
- if inputs_embeds is None:
- input_ids = input_ids[:, cache_position]
- else:
- # This is the code we need to implemented with torch.cond.
- # if input_ids.shape[1] == 0:
- # inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
- # else:
- # if cache_position[-1] >= input_ids.shape[1]:
- # input_ids = input_ids[:, -cache_position.shape[0] :]
- # else:
- # if input_ids.shape[1] != cache_position.shape[0]:
- # input_ids = input_ids[:, cache_position]
- # We need to clone the outputs to avoid aliasing.
- def branch_1(inputs_embeds, cache_position):
- return inputs_embeds[:, -cache_position.shape[0] :].clone()
- def branch_2(input_ids, cache_position):
- return input_ids[:, -cache_position.shape[0] :].clone()
- def branch_3(input_ids, cache_position):
- return input_ids[:, cache_position].clone()
- inputs_embeds, input_ids = torch.cond(
- input_ids.shape[1] == 0,
- (
- lambda input_ids, inputs_embeds, cache_position: (
- branch_1(inputs_embeds, cache_position),
- input_ids.clone(),
- )
- ),
- (
- lambda input_ids, inputs_embeds, cache_position: (
- inputs_embeds,
- torch.cond(
- cache_position[-1] >= input_ids.shape[1],
- branch_2,
- lambda input_ids, cache_position: (
- torch.cond(
- input_ids.shape[1] != cache_position.shape[0],
- branch_3,
- (lambda input_ids, cache_position: input_ids.clone()),
- [input_ids, cache_position],
- )
- ),
- [input_ids, cache_position],
- ),
- )
- ),
- [input_ids, inputs_embeds, cache_position],
- )
- return inputs_embeds, input_ids
def prepare_inputs_for_generation(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[Cache] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """
    Assemble the dictionary of inputs for the next forward pass of the generation loop.

    The main steps are: picking the right main input key (ids vs. embeddings) and cloning it, building
    `position_ids` from the attention mask when they were not supplied, slicing the inputs that must match the
    length of `input_ids`, turning a 2D attention mask into a 4D one for compilable caches, and finally
    forwarding every remaining keyword argument untouched.

    See the forward pass in the model documentation for expected arguments (different models might have
    different requirements for e.g. `past_key_values`). This function should work as is for most LLMs.

    Returns:
        A `dict` mapping forward-pass argument names to their values.
    """
    is_encoder_decoder = self.config.is_encoder_decoder
    # Encoder-decoder models take the generated sequence (and its mask / positions) under `decoder_*` names
    if is_encoder_decoder:
        input_ids_key = "decoder_input_ids"
        attention_mask_key = "decoder_attention_mask"
        position_ids_key = "decoder_position_ids"
    else:
        input_ids_key = "input_ids"
        attention_mask_key = "attention_mask"
        position_ids_key = "position_ids"

    # 1. Handle BC:
    model_inputs = {"cache_position": cache_position}

    # 2. Generic cache-dependent input preparation
    if past_key_values is not None:
        model_inputs["past_key_values"] = past_key_values
        # TODO (joao): handle the case where cache length == input_ids length. The function below results in an
        # exception because we get empty input_ids after slicing. In essence, we need to roll back the cache 1
        # token to recompute the logits for the first token to be generated (but not all caches support roll backs)
        inputs_embeds, input_ids = self._cache_dependant_input_preparation(
            input_ids, inputs_embeds, cache_position
        )

    # 3. Prepare base model inputs
    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step for every prompt.
    feed_embeds = (
        not is_encoder_decoder
        and inputs_embeds is not None
        and len(cache_position) == inputs_embeds.shape[1]
    )
    if feed_embeds:
        model_inputs[input_ids_key] = None
        model_inputs["inputs_embeds"] = inputs_embeds
    else:
        # `clone` calls in this function ensure a consistent stride. See #32227
        model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
        if not is_encoder_decoder:
            model_inputs["inputs_embeds"] = None

    # 4. Create missing `position_ids` on the fly
    # For encoder-decoder models the incoming `attention_mask` belongs to the encoder; the mask of the
    # generated sequence is read from `decoder_attention_mask` in kwargs.
    if is_encoder_decoder:
        encoder_attention_mask = attention_mask
        attention_mask = kwargs.pop("decoder_attention_mask", None)
    else:
        encoder_attention_mask = None

    if (
        attention_mask is not None
        and kwargs.get(position_ids_key) is None
        and position_ids_key in inspect.signature(self.forward).parameters
    ):
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        kwargs[position_ids_key] = position_ids  # placed in kwargs for further processing (see below)

    # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
    for model_input_name in ("position_ids", "token_type_ids", "decoder_position_ids"):
        model_input = kwargs.get(model_input_name)
        if model_input is None:
            continue
        if past_key_values is not None:
            # the main input is whichever of embeddings / ids was actually set above
            main_input = model_inputs.get("inputs_embeds")
            if main_input is None:
                main_input = model_inputs[input_ids_key]
            current_input_length = main_input.shape[1]
            model_input = model_input[:, -current_input_length:]
            model_input = model_input.clone(memory_format=torch.contiguous_format)
        model_inputs[model_input_name] = model_input

    # 6. Create 4D attention mask if we are using a compilable cache (important for performant compiled forward
    # pass)
    needs_4d_mask = (
        isinstance(past_key_values, Cache)
        and past_key_values.is_compileable
        and attention_mask is not None
        and attention_mask.ndim == 2
    )
    if needs_4d_mask:
        if not is_encoder_decoder and model_inputs["inputs_embeds"] is not None:
            batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
        else:
            batch_size, sequence_length = model_inputs[input_ids_key].shape[:2]

        # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
        # the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder.
        legacy_mask_fn_name = "_prepare_4d_causal_attention_mask_with_cache_position"
        base_model = getattr(self, self.base_model_prefix, self)
        decoder = base_model.get_decoder() if hasattr(base_model, "get_decoder") else None
        causal_mask_creation_function = getattr(base_model, legacy_mask_fn_name, None)
        if causal_mask_creation_function is None and decoder is not None:  # it may be in the decoder
            causal_mask_creation_function = getattr(decoder, legacy_mask_fn_name, None)

        if causal_mask_creation_function is not None:
            attention_mask = causal_mask_creation_function(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_cache_shape(),
                dtype=self.dtype,
                cache_position=cache_position,
                batch_size=batch_size,
                config=self.config,
                past_key_values=past_key_values,
            )
        else:
            # Not defined anywhere -> the model uses the new general mask API
            token_type_ids = model_inputs.get("token_type_ids")
            position_ids = model_inputs.get(position_ids_key)
            # Some models may overwrite the general one
            causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
            attention_mask = causal_mask_creation_function(
                config=self.config,
                # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
                input_embeds=torch.empty((batch_size, sequence_length), dtype=self.dtype),
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=position_ids,
                token_type_ids=token_type_ids,
            )

    if attention_mask is not None:
        model_inputs[attention_mask_key] = attention_mask

    if encoder_attention_mask is not None:
        model_inputs["attention_mask"] = encoder_attention_mask

    # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
    for key, value in kwargs.items():
        model_inputs.setdefault(key, value)

    # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
    model_inputs.pop("labels", None)
    return model_inputs
- def _prepare_model_inputs(
- self,
- inputs: Optional[torch.Tensor] = None,
- bos_token_id: Optional[torch.Tensor] = None,
- model_kwargs: Optional[dict[str, torch.Tensor]] = None,
- ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
- """
- This function extracts the model-specific `inputs` for generation.
- """
- # 1. retrieve all kwargs that are non-None or non-model input related.
- # some encoder-decoder models have different names for model and encoder
- if (
- self.config.is_encoder_decoder
- and hasattr(self, "encoder")
- and self.encoder.main_input_name != self.main_input_name
- ):
- input_name = self.encoder.main_input_name
- else:
- input_name = self.main_input_name
- model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
- # 2. check whether model_input_name is passed as kwarg
- # if yes and `inputs` is None use kwarg inputs
- inputs_kwarg = model_kwargs.pop(input_name, None)
- if inputs_kwarg is not None and inputs is not None:
- raise ValueError(
- f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. "
- f"Make sure to either pass {inputs} or {input_name}=..."
- )
- elif inputs_kwarg is not None:
- inputs = inputs_kwarg
- # 3. In the presence of `inputs_embeds` for text models:
- # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
- # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
- # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
- # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
- # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
- if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
- if model_kwargs["inputs_embeds"] is None:
- model_kwargs.pop("inputs_embeds")
- elif not self.config.is_encoder_decoder:
- has_inputs_embeds_forwarding = "inputs_embeds" in set(
- inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
- )
- if not has_inputs_embeds_forwarding:
- raise ValueError(
- f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
- "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
- "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
- )
- # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
- # the attention mask) can rely on the actual model input.
- model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
- inputs, bos_token_id, model_kwargs=model_kwargs
- )
- inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
- else:
- if inputs is not None:
- raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
- inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
- # 4. if `inputs` is still None, try to create `input_ids` from BOS token
- inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
- return inputs, input_name, model_kwargs
- def _maybe_initialize_input_ids_for_generation(
- self,
- inputs: Optional[torch.Tensor] = None,
- bos_token_id: Optional[torch.Tensor] = None,
- model_kwargs: Optional[dict[str, torch.Tensor]] = None,
- ) -> torch.LongTensor:
- """Initializes input ids for generation, if necessary."""
- if inputs is not None:
- return inputs
- encoder_outputs = model_kwargs.get("encoder_outputs")
- if self.config.is_encoder_decoder and encoder_outputs is not None:
- # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
- shape = encoder_outputs.last_hidden_state.size()[:-1]
- return torch.ones(shape, dtype=torch.long, device=self.device) * -100
- # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
- # soft-prompting or in multimodal implementations built on top of decoder-only language models.
- batch_size = 1
- for value in model_kwargs.values():
- if isinstance(value, torch.Tensor):
- batch_size = value.shape[0]
- break
- if "inputs_embeds" in model_kwargs:
- return torch.ones((batch_size, 0), dtype=torch.long, device=self.device)
- if bos_token_id is None:
- raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
- return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
def _prepare_attention_mask_for_generation(
    self,
    inputs_tensor: torch.Tensor,
    generation_config: GenerationConfig,
    model_kwargs: dict[str, Any],
) -> torch.LongTensor:
    """
    Build an attention mask for the main input when none was provided.

    An all-ones mask over the first two dimensions of the input is the default. When the input looks like token
    ids (2D, integer dtype), contains the pad token, and the pad token is not one of the EOS tokens, the mask is
    instead derived from the padding (0 where the input equals the pad token, 1 elsewhere).

    Returns:
        A `torch.long` mask with the same first two dimensions as the tensor it was inferred from.
    """
    pad_token_id = generation_config._pad_token_tensor
    eos_token_id = generation_config._eos_token_tensor

    # `input_ids` may be present in the model kwargs, instead of being the main input (e.g. multimodal model)
    if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
        inputs_tensor = model_kwargs["input_ids"]

    # No information for attention mask inference -> return default attention mask
    default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
    if pad_token_id is None:
        return default_attention_mask

    looks_like_input_ids = inputs_tensor.ndim == 2 and inputs_tensor.dtype in (torch.int, torch.long)
    if not looks_like_input_ids:
        return default_attention_mask

    # The checks below stay as tensor arithmetic (no Python branching on tensor values): the final mask is a
    # blend of the padding-based mask and the default mask, selected by `can_infer_attention_mask`.
    pad_token_in_inputs = isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
    if eos_token_id is None:
        pad_token_differs_from_eos = True
    else:
        pad_token_differs_from_eos = ~(isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any())
    can_infer_attention_mask = pad_token_in_inputs * pad_token_differs_from_eos

    padding_based_mask = inputs_tensor.ne(pad_token_id).long()
    return padding_based_mask * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
def _prepare_encoder_decoder_kwargs_for_generation(
    self,
    inputs_tensor: torch.Tensor,
    model_kwargs,
    model_input_name: Optional[str],
    generation_config: GenerationConfig,
) -> dict[str, Any]:
    """
    Run the encoder once and store its output in `model_kwargs["encoder_outputs"]`.

    Only the keyword arguments that are relevant to the encoder are forwarded: entries whose name starts with
    `decoder_`, `cross_attn` or `use_cache` are dropped, and, unless the encoder's `forward` has a parameter
    named `kwargs` or `model_kwargs`, so is anything that is not a parameter of that `forward`.

    Returns:
        The same `model_kwargs` dictionary, updated in place with the `encoder_outputs` entry.
    """
    # 1. get encoder
    encoder = self.get_encoder()
    # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
    # as the inputs.
    if hasattr(self, "hf_device_map"):
        if hasattr(encoder, "_hf_hook"):
            encoder._hf_hook.io_same_device = True
        else:
            add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))

    # 2. Prepare encoder args and encoder kwargs from model kwargs and generation config.
    irrelevant_prefixes = ("decoder_", "cross_attn", "use_cache")
    encoder_parameters = set(inspect.signature(encoder.forward).parameters)
    accepts_any_kwarg = "kwargs" in encoder_parameters or "model_kwargs" in encoder_parameters

    encoder_kwargs = {}
    for argument, value in model_kwargs.items():
        if argument.startswith(irrelevant_prefixes):
            continue
        if accepts_any_kwarg or argument in encoder_parameters:
            encoder_kwargs[argument] = value

    encoder_kwargs["output_attentions"] = generation_config.output_attentions
    encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states

    # 3. make sure that encoder returns `ModelOutput`
    if model_input_name is None:
        model_input_name = self.main_input_name
    encoder_kwargs["return_dict"] = True
    encoder_kwargs[model_input_name] = inputs_tensor

    model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs)
    return model_kwargs
- def _prepare_decoder_input_ids_for_generation(
- self,
- batch_size: int,
- model_input_name: str,
- model_kwargs: dict[str, torch.Tensor],
- decoder_start_token_id: torch.Tensor,
- device: Optional[torch.device] = None,
- ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
- """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
- # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
- # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
- if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
- decoder_input_ids = model_kwargs.pop("decoder_input_ids")
- elif "input_ids" in model_kwargs and model_input_name != "input_ids":
- decoder_input_ids = model_kwargs.pop("input_ids")
- else:
- decoder_input_ids = None
- # 2. `decoder_start_token_id` must have shape (batch_size, 1)
- if device is None:
- device = self.device
- if decoder_start_token_id.ndim == 1:
- if decoder_start_token_id.shape[0] != batch_size:
- raise ValueError(
- f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}"
- )
- decoder_start_token_id = decoder_start_token_id.view(-1, 1)
- else:
- decoder_start_token_id = (
- torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
- )
- # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
- # no user input -> use decoder_start_token_id as decoder_input_ids
- if decoder_input_ids is None:
- decoder_input_ids = decoder_start_token_id
- # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the
- # original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic.
- # See: https://github.com/huggingface/transformers/pull/31470
- elif "donut" in self.__class__.__name__.lower() or (
- self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower()
- ):
- pass
- elif self.config.model_type == "whisper":
- pass
- # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
- # decoder_attention_mask if provided)
- elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item():
- decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1)
- if "decoder_attention_mask" in model_kwargs:
- decoder_attention_mask = model_kwargs["decoder_attention_mask"]
- decoder_attention_mask = torch.cat(
- (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
- dim=-1,
- )
- model_kwargs["decoder_attention_mask"] = decoder_attention_mask
- return decoder_input_ids, model_kwargs
- @staticmethod
- def _expand_inputs_for_generation(
- expand_size: int = 1,
- is_encoder_decoder: bool = False,
- input_ids: Optional[torch.LongTensor] = None,
- **model_kwargs,
- ) -> tuple[torch.LongTensor, dict[str, Any]]:
- """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
- # Do not call torch.repeat_interleave if expand_size is 1 because it clones
- # the input tensor and thus requires more memory although no change is applied
- if expand_size == 1:
- return input_ids, model_kwargs
- def _expand_dict_for_generation(dict_to_expand):
- for key in dict_to_expand:
- if (
- key != "cache_position"
- and dict_to_expand[key] is not None
- and isinstance(dict_to_expand[key], torch.Tensor)
- ):
- dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
- return dict_to_expand
- if input_ids is not None:
- input_ids = input_ids.repeat_interleave(expand_size, dim=0)
- model_kwargs = _expand_dict_for_generation(model_kwargs)
- if is_encoder_decoder:
- if model_kwargs.get("encoder_outputs") is None:
- raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
- model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
- return input_ids, model_kwargs
def _update_model_kwargs_for_generation(
    self,
    outputs: ModelOutput,
    model_kwargs: dict[str, Any],
    is_encoder_decoder: bool = False,
    num_new_tokens: int = 1,
) -> dict[str, Any]:
    """
    Update `model_kwargs` after a forward pass so that it is ready for the next generation step.

    The cache returned in `outputs` is stored, `token_type_ids` and the (decoder) attention mask grow by one
    column, and `cache_position` is advanced by `num_new_tokens`.

    Returns:
        The same `model_kwargs` dictionary, updated in place.
    """
    # update past_key_values keeping its naming used in model code
    for possible_cache_name in ALL_CACHE_NAMES:
        if possible_cache_name not in outputs:
            continue
        # TODO (joao): remove output/input mismatch when these old models (xlnet, reformer) are deprecated
        renamed_on_input = possible_cache_name in ("past_buckets_states", "mems")
        cache_name = "past_key_values" if renamed_on_input else possible_cache_name
        model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
        break  # only the first cache name found is used

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        last_token_type = token_type_ids[:, -1].unsqueeze(-1)
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, last_token_type], dim=-1)

    # update the attention mask of the generated sequence: one extra column of ones
    mask_name = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
    if mask_name in model_kwargs:
        mask = model_kwargs[mask_name]
        extra_column = mask.new_ones((mask.shape[0], 1))
        model_kwargs[mask_name] = torch.cat([mask, extra_column], dim=-1)

    if model_kwargs.get("use_cache", True):
        # with a cache, only the last position is kept and shifted forward
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
    else:
        # without a cache, all past positions are kept and the new ones are appended
        past_positions = model_kwargs.pop("cache_position")
        last_position = past_positions[-1]
        new_positions = torch.arange(
            last_position + 1, last_position + num_new_tokens + 1, dtype=past_positions.dtype
        ).to(past_positions.device)
        model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))

    return model_kwargs
def _get_candidate_generator(
    self,
    generation_config: GenerationConfig,
    input_ids: torch.LongTensor,
    inputs_tensor: torch.Tensor,
    logits_processor: LogitsProcessorList,
    model_kwargs: dict[str, Any],
    assistant_model: Optional["PreTrainedModel"] = None,
    target_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
    assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
) -> CandidateGenerator:
    """
    Returns the candidate generator to be used in `assisted_generation`

    The choice is made in this order of precedence:
    1. `generation_config.assistant_early_exit` is set -> early-exit generator (this model assists itself);
    2. `generation_config.prompt_lookup_num_tokens` is set -> prompt lookup generator;
    3. an assistant model and both tokenizers are given -> a generator for assistants with a different
       tokenizer, picked according to `generation_config.do_sample`;
    4. otherwise -> the regular assisted candidate generator.

    Raises:
        ValueError: In case 3, if `generation_config.do_sample` is neither `True` nor `False`.
    """
    different_tokenizers = not any(v is None for v in (assistant_model, target_tokenizer, assistant_tokenizer))

    if generation_config.assistant_early_exit is not None:
        return EarlyExitCandidateGenerator(
            input_ids=input_ids,
            assistant_model=self,
            generation_config=generation_config,
            model_kwargs=model_kwargs,
            inputs_tensor=inputs_tensor,
            logits_processor=logits_processor,
        )

    if generation_config.prompt_lookup_num_tokens is not None:
        return PromptLookupCandidateGenerator(
            eos_token_id=generation_config._eos_token_tensor,
            num_output_tokens=generation_config.prompt_lookup_num_tokens,
            max_matching_ngram_size=generation_config.max_matching_ngram_size or 2,
            max_length=generation_config.max_length,
            logits_processor=logits_processor,
            vocab_size=self.config.get_text_config().vocab_size,
        )

    # Arguments shared by every assistant-model based generator below
    assistant_generator_kwargs = {
        "input_ids": input_ids,
        "assistant_model": assistant_model,
        "generation_config": generation_config,
        "model_kwargs": model_kwargs,
        "inputs_tensor": inputs_tensor,
        "logits_processor": logits_processor,
    }

    if not different_tokenizers:
        return AssistedCandidateGenerator(**assistant_generator_kwargs)

    if generation_config.do_sample is True:
        atm_translator = AssistantVocabTranslatorCache.get_translator(
            target_tokenizer,
            assistant_tokenizer,
            self.config.get_text_config().vocab_size,
            assistant_model=assistant_model,
            assistant_prune_lm_head=True,  # prune LM head of assistant model
        )
        # Since we prune the LM head, we cannot use the repetition penalty on the assistant model due to mismatches between token ids and logits index
        assistant_model.generation_config.repetition_penalty = None
        return UniversalSpeculativeDecodingGenerator(
            target_tokenizer=target_tokenizer,
            assistant_tokenizer=assistant_tokenizer,
            atm_translator=atm_translator,
            **assistant_generator_kwargs,
        )

    if generation_config.do_sample is False:
        return AssistedCandidateGeneratorDifferentTokenizers(
            target_tokenizer=target_tokenizer,
            assistant_tokenizer=assistant_tokenizer,
            **assistant_generator_kwargs,
        )

    raise ValueError(
        f"Invalid value for `do_sample`: expected a boolean, got {type(generation_config.do_sample).__name__}"
    )
def _get_logits_processor(
    self,
    generation_config: GenerationConfig,
    input_ids_seq_length: Optional[int] = None,
    encoder_input_ids: Optional[torch.LongTensor] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    device: Optional[str] = None,
    model_kwargs: Optional[dict[str, Any]] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
) -> LogitsProcessorList:
    """
    This method returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
    instances used to modify the scores of the language model head.

    Args:
        generation_config: Flags and values deciding which processors get instantiated.
        input_ids_seq_length: Prompt length handed to the processors that take it (`min_new_tokens`,
            `exponential_decay_length_penalty`, `begin_suppress_tokens`).
        encoder_input_ids: Only read by the two `encoder_*` processors. When it is `None` or not 2D, those
            processors are skipped with a `UserWarning`.
        prefix_allowed_tokens_fn: When set, wrapped in a `PrefixConstrainedLogitsProcessor`.
        logits_processor: User-defined processors, merged with the generated ones through
            `_merge_criteria_processor_list`. `None` is treated as an empty list.
        device: Forwarded to the processors that accept a `device` argument.
        model_kwargs: Not read by this method.
        negative_prompt_ids, negative_prompt_attention_mask: Forwarded to the classifier-free guidance processor.

    Returns:
        The merged processors, followed by the sampling-only processors (if `do_sample`), the watermarking
        processor and, last, `LogitNormalization`.
    """
    # instantiate processors list
    processors = LogitsProcessorList()
    if logits_processor is None:
        logits_processor = []

    if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:
        processors.append(
            UnbatchedClassifierFreeGuidanceLogitsProcessor(
                generation_config.guidance_scale,
                self,
                unconditional_ids=negative_prompt_ids,
                unconditional_attention_mask=negative_prompt_attention_mask,
                use_cache=generation_config.use_cache,
            )
        )
    if generation_config.sequence_bias is not None:
        processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))

    if (
        generation_config.encoder_repetition_penalty is not None
        and generation_config.encoder_repetition_penalty != 1.0
    ):
        # `encoder_input_ids` defaults to `None`: guard before touching `.shape` so that the warning below is
        # reached instead of an `AttributeError`
        if encoder_input_ids is not None and len(encoder_input_ids.shape) == 2:
            processors.append(
                EncoderRepetitionPenaltyLogitsProcessor(
                    penalty=generation_config.encoder_repetition_penalty,
                    encoder_input_ids=encoder_input_ids,
                )
            )
        else:
            warnings.warn(
                "Passing `encoder_repetition_penalty` requires some form of `input_ids` to be passed to "
                "`generate`, ignoring the argument.",
                UserWarning,
            )
    if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
    if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
        processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
    if (
        generation_config.encoder_no_repeat_ngram_size is not None
        and generation_config.encoder_no_repeat_ngram_size > 0
    ):
        # same `None` guard as for `encoder_repetition_penalty`
        if encoder_input_ids is not None and len(encoder_input_ids.shape) == 2:
            processors.append(
                EncoderNoRepeatNGramLogitsProcessor(
                    generation_config.encoder_no_repeat_ngram_size,
                    encoder_input_ids,
                )
            )
        else:
            warnings.warn(
                "Passing `encoder_no_repeat_ngram_size` requires some form of `input_ids` to be passed to "
                "`generate`, ignoring the argument.",
                UserWarning,
            )
    if generation_config.bad_words_ids is not None:
        processors.append(
            NoBadWordsLogitsProcessor(
                generation_config.bad_words_ids,
                generation_config._eos_token_tensor,
            )
        )
    if (
        generation_config.min_length is not None
        and getattr(generation_config, "_eos_token_tensor", None) is not None
        and generation_config.min_length > 0
    ):
        processors.append(
            MinLengthLogitsProcessor(
                generation_config.min_length,
                generation_config._eos_token_tensor,
                device=device,
            )
        )
    if (
        generation_config.min_new_tokens is not None
        and getattr(generation_config, "_eos_token_tensor", None) is not None
        and generation_config.min_new_tokens > 0
    ):
        processors.append(
            MinNewTokensLengthLogitsProcessor(
                input_ids_seq_length,
                generation_config.min_new_tokens,
                generation_config._eos_token_tensor,
                device=device,
            )
        )
    if prefix_allowed_tokens_fn is not None:
        processors.append(
            PrefixConstrainedLogitsProcessor(
                prefix_allowed_tokens_fn,
                generation_config.num_beams,
            )
        )
    if generation_config.forced_bos_token_id is not None:
        processors.append(
            ForcedBOSTokenLogitsProcessor(
                generation_config.forced_bos_token_id,
            )
        )
    if generation_config.forced_eos_token_id is not None:
        processors.append(
            ForcedEOSTokenLogitsProcessor(
                generation_config.max_length,
                generation_config.forced_eos_token_id,
                device=device,
            )
        )
    if generation_config.remove_invalid_values is True:
        processors.append(InfNanRemoveLogitsProcessor())
    if generation_config.exponential_decay_length_penalty is not None:
        processors.append(
            ExponentialDecayLengthPenalty(
                generation_config.exponential_decay_length_penalty,
                generation_config._eos_token_tensor,
                input_ids_seq_length,
            )
        )
    if generation_config.suppress_tokens is not None:
        processors.append(
            SuppressTokensLogitsProcessor(
                generation_config.suppress_tokens,
                device=device,
            )
        )
    if generation_config.begin_suppress_tokens is not None:
        # shift the start index by one when the prompt is at most one token long and a BOS token is forced
        begin_index = input_ids_seq_length
        begin_index = (
            begin_index
            if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
            else begin_index + 1
        )
        processors.append(
            SuppressTokensAtBeginLogitsProcessor(
                generation_config.begin_suppress_tokens,
                begin_index,
                device=device,
            )
        )

    # TODO (joao): find a strategy to specify the order of the processors
    processors = self._merge_criteria_processor_list(processors, logits_processor)

    # Processors previously known as `LogitsWarpers`, only applied with sampling strategies
    if generation_config.do_sample:
        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
        # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
        if generation_config.num_beams > 1:
            if isinstance(generation_config._eos_token_tensor, list):
                min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
            elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
                min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
            else:
                min_tokens_to_keep = 2
        else:
            min_tokens_to_keep = 1

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if generation_config.temperature is not None and generation_config.temperature != 1.0:
            processors.append(TemperatureLogitsWarper(generation_config.temperature))
        if generation_config.top_k is not None and generation_config.top_k != 0:
            processors.append(
                TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)
            )
        if generation_config.top_p is not None and generation_config.top_p < 1.0:
            processors.append(
                TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)
            )
        if generation_config.min_p is not None:
            # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
            processors.append(
                MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)
            )
        if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
            processors.append(
                TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
            )
        if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
            processors.append(
                EpsilonLogitsWarper(
                    epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep
                )
            )
        if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
            processors.append(
                EtaLogitsWarper(
                    epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
                )
            )

    # Watermarking should be after all logits processing is finished (see #34630)
    if generation_config.watermarking_config is not None:
        processors.append(
            generation_config.watermarking_config.construct_processor(
                self.config.get_text_config().vocab_size, device
            )
        )

    # `LogitNormalization` should always be the last logit processor, when present
    if generation_config.renormalize_logits is True:
        processors.append(LogitNormalization())
    return processors
def _get_stopping_criteria(
    self,
    generation_config: GenerationConfig,
    stopping_criteria: Optional[StoppingCriteriaList],
    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
) -> StoppingCriteriaList:
    """
    Build the [`StoppingCriteriaList`] implied by `generation_config` and merge it with the user-defined
    `stopping_criteria`.

    Criteria are added, in this order, for: `max_length`, `max_time`, `stop_strings`, the EOS token tensor and
    (assistant models only) a strictly positive `assistant_confidence_threshold`.

    Raises:
        ValueError: if `stop_strings` is set but no `tokenizer` was provided.
    """
    built_in = StoppingCriteriaList()

    max_length = generation_config.max_length
    if max_length is not None:
        built_in.append(
            MaxLengthCriteria(
                max_length=max_length,
                max_position_embeddings=getattr(self.config, "max_position_embeddings", None),
            )
        )

    max_time = generation_config.max_time
    if max_time is not None:
        built_in.append(MaxTimeCriteria(max_time=max_time))

    stop_strings = generation_config.stop_strings
    if stop_strings is not None:
        # stop strings are matched through the tokenizer, so it is mandatory here
        if tokenizer is None:
            raise ValueError(
                "There are one or more stop strings, either in the arguments to `generate` or in the "
                "model's generation config, but we could not locate a tokenizer. When generating with "
                "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`."
            )
        built_in.append(StopStringCriteria(stop_strings=stop_strings, tokenizer=tokenizer))

    eos_token_tensor = generation_config._eos_token_tensor
    if eos_token_tensor is not None:
        built_in.append(EosTokenCriteria(eos_token_id=eos_token_tensor))

    if generation_config.is_assistant:
        confidence_threshold = generation_config.assistant_confidence_threshold
        if confidence_threshold is not None and confidence_threshold > 0:
            built_in.append(ConfidenceCriteria(assistant_confidence_threshold=confidence_threshold))

    return self._merge_criteria_processor_list(built_in, stopping_criteria)
def _merge_criteria_processor_list(
    self,
    default_list: Union[LogitsProcessorList, StoppingCriteriaList],
    custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
) -> Union[LogitsProcessorList, StoppingCriteriaList]:
    """
    Combine the processors/criteria instantiated inside `generate` with the user-defined ones. Whenever both
    lists hold an object of the exact same type, the user-defined object replaces the default one (and a warning
    is logged once).
    (Note: up to v4.49.0, this function threw an exception if the same logit processor was found twice.)

    Returns `default_list` itself when `custom_list` is empty; otherwise a new list of the same type as
    `default_list`, holding the (possibly replaced) defaults first and the remaining custom objects after them.
    """
    if len(custom_list) == 0:
        return default_list

    merged = type(default_list)()
    for default_item in default_list:
        for custom_item in custom_list:
            if type(custom_item) is type(default_item):
                custom_type = type(custom_item)
                if isinstance(custom_item, StoppingCriteria):
                    object_type = "stopping criteria"
                else:
                    object_type = "logits processor"
                logger.warning_once(
                    f"A custom {object_type} of type {custom_type} has been passed to `.generate()`, but it "
                    f"was also created in `.generate()`, given its parameterization. The custom {custom_type} "
                    f"will take precedence. Please check the docstring of {custom_type} to see related "
                    "`.generate()` flags."
                )
                merged.append(custom_item)
                break
        else:
            # no custom object shares this type: keep the default
            merged.append(default_item)

    # custom objects that did not replace a default are appended at the end
    for custom_item in custom_list:
        if custom_item not in merged:
            merged.append(custom_item)
    return merged
def compute_transition_scores(
    self,
    sequences: torch.Tensor,
    scores: tuple[torch.Tensor],
    beam_indices: Optional[torch.Tensor] = None,
    normalize_logits: bool = False,
) -> torch.Tensor:
    """
    Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was
    used). This is a convenient method to quickly obtain the scores of the selected tokens at generation time.

    Parameters:
        sequences (`torch.LongTensor`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or
            shorter if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`):
            Transition scores for each vocabulary token at each generation step. Beam transition scores consisting
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, sequence_length)`. Only required if a `num_beams>1` at
            generate-time.
        normalize_logits (`bool`, *optional*, defaults to `False`):
            Whether to normalize the logits (which, for legacy reasons, may be unnormalized).

    Return:
        `torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)` containing
            the transition scores (logits)

    Examples:

    ```python
    >>> from transformers import GPT2Tokenizer, AutoModelForCausalLM
    >>> import numpy as np

    >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    >>> tokenizer.pad_token_id = tokenizer.eos_token_id
    >>> inputs = tokenizer(["Today is"], return_tensors="pt")

    >>> # Example 1: Print the scores for each token generated with Greedy Search
    >>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
    >>> transition_scores = model.compute_transition_scores(
    ...     outputs.sequences, outputs.scores, normalize_logits=True
    ... )
    >>> # input_length is the length of the input prompt for decoder-only models, like the GPT family, and 1 for
    >>> # encoder-decoder models, like BART or T5.
    >>> input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
    >>> generated_tokens = outputs.sequences[:, input_length:]
    >>> for tok, score in zip(generated_tokens[0], transition_scores[0]):
    ...     # | token | token string | log probability | probability
    ...     print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
    |   262 |  the     | -1.414 | 24.33%
    |  1110 |  day     | -2.609 | 7.36%
    |   618 |  when    | -2.010 | 13.40%
    |   356 |  we      | -1.859 | 15.58%
    |   460 |  can     | -2.508 | 8.14%

    >>> # Example 2: Reconstruct the sequence scores from Beam Search
    >>> outputs = model.generate(
    ...     **inputs,
    ...     max_new_tokens=5,
    ...     num_beams=4,
    ...     num_return_sequences=4,
    ...     return_dict_in_generate=True,
    ...     output_scores=True,
    ... )
    >>> transition_scores = model.compute_transition_scores(
    ...     outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
    ... )
    >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores.
    >>> # Tip 1: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the
    >>> # use case, you might want to recompute it with `normalize_logits=True`.
    >>> # Tip 2: the output length does NOT include the input length
    >>> output_length = np.sum(transition_scores.numpy() < 0, axis=1)
    >>> length_penalty = model.generation_config.length_penalty
    >>> reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)
    >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores))
    True
    ```"""
    # The stride of the flattened (row, token) index below is the size of the last axis of the score tensors.
    # Read it from the scores themselves rather than from the config, so that the gather stays aligned when the
    # two differ.
    num_steps = len(scores)
    vocab_size = scores[0].shape[-1]

    # 1. In absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent
    # to a beam search approach were the first (and only) beam is always selected
    if beam_indices is None:
        beam_indices = torch.arange(scores[0].shape[0], device=sequences.device).view(-1, 1)
        beam_indices = beam_indices.expand(-1, num_steps)

    # 2. reshape scores as [batch_size*vocab_size, # generation steps] with # generation steps being
    # seq_len - input_length
    scores = torch.stack(scores).reshape(num_steps, -1).transpose(0, 1)

    # 3. Optionally normalize the logits (across the vocab dimension)
    if normalize_logits:
        scores = scores.reshape(-1, vocab_size, scores.shape[-1])
        scores = torch.nn.functional.log_softmax(scores, dim=1)
        scores = scores.reshape(-1, scores.shape[-1])

    # 4. cut beam_indices to longest beam length
    beam_indices_mask = beam_indices < 0
    max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
    beam_indices = beam_indices.clone()[:, :max_beam_length]
    beam_indices_mask = beam_indices_mask[:, :max_beam_length]

    # 5. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards
    beam_indices[beam_indices_mask] = 0

    # 6. multiply beam_indices with vocab size to gather correctly from scores
    beam_sequence_indices = beam_indices * vocab_size

    # 7. Define which indices contributed to scores
    cut_idx = sequences.shape[-1] - max_beam_length
    indices = sequences[:, cut_idx:] + beam_sequence_indices

    # 8. Compute scores
    transition_scores = scores.gather(0, indices)

    # 9. Mask out transition_scores of beams that stopped early
    transition_scores[beam_indices_mask] = 0

    return transition_scores
def _validate_generation_mode(self, generation_mode, generation_config, generation_mode_kwargs):
    """
    Reject combinations of generation mode, generation config and mode-specific kwargs that cannot work together.

    Raises:
        ValueError: for beam search with a `streamer`; for assisted generation with more than one return
            sequence or with a stateful model; for an assistant model whose encoder-related config attributes
            differ from the main model's; and for a tokenizer/assistant-tokenizer setup that does not fit the
            vocabulary sizes of the two models.
    """
    streamer_requested = "streamer" in generation_mode_kwargs
    if generation_mode == GenerationMode.BEAM_SEARCH and streamer_requested:
        raise ValueError(
            "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
        )

    if generation_mode == GenerationMode.ASSISTED_GENERATION:
        if generation_config.num_return_sequences > 1:
            raise ValueError(
                "num_return_sequences has to be 1 when doing assisted generate, "
                f"but is {generation_config.num_return_sequences}."
            )
        if self._is_stateful:
            # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
            # which is not possible with stateful models (they can't reset to a previous subset of generated text)
            raise ValueError(
                f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
            )

    assistant_model = generation_mode_kwargs.get("assistant_model")
    if assistant_model is None:
        return

    # An encoder-decoder main model paired with a decoder-only assistant: the encoder-related attributes that
    # the assistant config exposes must match the main model's
    if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder:
        encoder_attributes = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"]
        shared_attributes = [name for name in dir(assistant_model.config) if name in encoder_attributes]
        configs_match = all(
            getattr(self.config, name) == getattr(assistant_model.config, name) for name in shared_attributes
        )
        if not configs_match:
            raise ValueError(
                "The main model and the assistant don't have compatible encoder-dependent input shapes. "
                "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
            )

    doc_reference = (
        "(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)"
    )
    main_vocab_size = self.config.get_text_config().vocab_size
    assistant_vocab_size = assistant_model.config.get_text_config().vocab_size
    if main_vocab_size == assistant_vocab_size:
        # equal vocabulary sizes are treated as "same tokenizer"
        if "assistant_tokenizer" in generation_mode_kwargs:
            raise ValueError(
                f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}."
            )
    elif not ("tokenizer" in generation_mode_kwargs and "assistant_tokenizer" in generation_mode_kwargs):
        raise ValueError(
            f"The main and assistant models have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}."
        )
def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
    """
    Check that every non-`None` entry of `model_kwargs` is accepted by the model; typos in `generate` arguments
    surface here as well.

    Note: for encoder-decoder models, `decoder_input_ids` is removed from `model_kwargs` in place.

    Raises:
        ValueError: listing the keys that no inspected signature (nor `TransformersKwargs`) accepts.
    """
    is_encoder_decoder = self.config.is_encoder_decoder

    # Excludes arguments that are handled before calling any model function
    if is_encoder_decoder:
        model_kwargs.pop("decoder_input_ids", None)

    accepted_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
    # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
    # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
    if accepted_args & {"kwargs", "model_kwargs"}:
        accepted_args |= set(inspect.signature(self.forward).parameters)

    # Encoder-Decoder models may also need Encoder arguments from `model_kwargs`
    if is_encoder_decoder:
        base_model = getattr(self, self.base_model_prefix, None)
        # `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`.
        # Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder`
        # TODO: A better way to handle this.
        # encoder arguments are accepted as-is, decoder arguments with a `decoder_` prefix
        for submodule_name, prefix in (("encoder", ""), ("decoder", "decoder_")):
            submodule = getattr(self, submodule_name, None)
            if submodule is None and base_model is not None:
                submodule = getattr(base_model, submodule_name, None)
            if submodule is not None:
                submodule_args = set(inspect.signature(submodule.forward).parameters)
                accepted_args |= {f"{prefix}{arg}" for arg in submodule_args}

    # TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions'
    unused_model_args = [
        key
        for key, value in model_kwargs.items()
        if value is not None and key not in accepted_args and key not in TransformersKwargs.__optional_keys__
    ]

    if unused_model_args:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
            " generate arguments will also show up in this list)"
        )
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
    """
    Validate the length-related settings of `generation_config` against the prompt length.

    Emits a `UserWarning` when the default `max_length` of 20 is relied upon, when `min_length` exceeds
    `max_length`, or when `min_new_tokens` plus the prompt length exceeds `max_length`.

    Raises:
        ValueError: if the prompt is already as long as (or longer than) `max_length`.
    """
    max_length = generation_config.max_length

    # 1. Max length warnings related to poor parameterization
    relies_on_default = has_default_max_length and generation_config.max_new_tokens is None and max_length == 20
    if relies_on_default:
        # 20 is the default max_length of the generation config
        warnings.warn(
            f"Using the model-agnostic default `max_length` (={max_length}) to control the "
            "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
            "generation.",
            UserWarning,
        )
    if input_ids_length >= max_length:
        if self.config.is_encoder_decoder:
            input_ids_string = "decoder_input_ids"
        else:
            input_ids_string = "input_ids"
        raise ValueError(
            f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
            f" {max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_length` or, better yet, setting `max_new_tokens`."
        )

    # 2. Min length warnings due to unfeasible parameter combinations
    advice = (
        " Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
        "increase the maximum length."
    )
    if has_default_max_length:
        advice = advice + f" Note that `max_length` is set to {max_length}, its default value."

    min_length = generation_config.min_length
    if min_length is not None and min_length > max_length:
        warnings.warn(
            f"Unfeasible length constraints: `min_length` ({min_length}) is larger than"
            f" the maximum possible length ({max_length})." + advice,
            UserWarning,
        )

    min_new_tokens = generation_config.min_new_tokens
    if min_new_tokens is not None and min_new_tokens + input_ids_length > max_length:
        warnings.warn(
            f"Unfeasible length constraints: `min_new_tokens` ({min_new_tokens}), when "
            f"added to the prompt length ({input_ids_length}), is larger than"
            f" the maximum possible length ({max_length})." + advice,
            UserWarning,
        )
def _prepare_generated_length(
    self,
    generation_config,
    has_default_max_length,
    has_default_min_length,
    model_input_name,
    input_ids_length,
    inputs_tensor,
):
    """
    Reconcile `max_length`/`max_new_tokens` and `min_length`/`min_new_tokens` on `generation_config`, so that
    similar attributes do not clash.

    `generation_config` is modified in place and also returned. The `*_new_tokens` attributes take precedence
    over their `*_length` counterparts (a warning is logged when both seem to have been set).
    """

    def lengths_need_embeds_correction():
        # if both `inputs_embeds` and `input_ids` are passed, we do not correct the length
        # otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length``
        return (
            model_input_name == "inputs_embeds"
            and input_ids_length != inputs_tensor.shape[1]
            and not self.config.is_encoder_decoder
        )

    max_new_tokens = generation_config.max_new_tokens
    if max_new_tokens is not None:
        if generation_config.max_length is not None and not has_default_max_length:
            logger.warning(
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )
        generation_config.max_length = max_new_tokens + input_ids_length
    elif lengths_need_embeds_correction():
        generation_config.max_length = generation_config.max_length - inputs_tensor.shape[1]
    elif has_default_max_length:  # by default let's always generate 20 new tokens
        if generation_config.max_length == GenerationConfig().max_length:
            extended_length = generation_config.max_length + input_ids_length
            generation_config.max_length = extended_length
            position_limit = getattr(self.config, "max_position_embeddings", None)
            if position_limit is not None:
                generation_config.max_length = min(extended_length, position_limit)

    # same for min length
    min_new_tokens = generation_config.min_new_tokens
    if min_new_tokens is not None:
        if not has_default_min_length:
            logger.warning(
                f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(="
                f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )
        generation_config.min_length = min_new_tokens + input_ids_length
    elif lengths_need_embeds_correction():
        generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)

    return generation_config
def _prepare_generation_config(
    self,
    generation_config: Optional[GenerationConfig],
    use_model_defaults: Optional[bool] = None,
    **kwargs: Any,
) -> tuple[GenerationConfig, dict]:
    """
    Resolve the generation config used for this call, then apply the generation options found in `kwargs`.
    Retrocompatibility with respect to configuration files is handled here as well.

    Returns:
        A deep copy of the resolved config (updated with `kwargs`) and the remaining model kwargs, as returned by
        `GenerationConfig.update`, plus `output_attentions`/`output_hidden_states` when they are truthy.
    """
    # parameterization priority:
    # kwargs > non-global default values in `generation_config` > `model.generation_config` > GenerationConfig()
    # TODO (joao): per-model generation config classes.
    using_model_generation_config = generation_config is None

    if using_model_generation_config:
        # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
        # the following conditions must be met
        # 1) the generation config must have been created from the model config (`_from_model_config` field);
        # 2) the generation config must have seen no modification since its creation (the hash is the same);
        # 3) there are non-default generation parameters in the model config.
        # 4) the user must have set new generation parameters in the model config.
        if (
            self.generation_config._from_model_config  # 1)
            and self.generation_config._original_object_hash == hash(self.generation_config)  # 2)
            and len(self.config._get_non_default_generation_parameters()) > 0  # 3)
        ):
            rebuilt_config = GenerationConfig.from_model_config(self.config)
            if rebuilt_config != self.generation_config:  # 4)
                warnings.warn(
                    "You have modified the pretrained model configuration to control generation. This is a"
                    " deprecated strategy to control generation and will be removed in v5."
                    " Please use and modify the model generation configuration (see"
                    " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
                    UserWarning,
                )
                self.generation_config = rebuilt_config
        generation_config = self.generation_config

        # Related to #40039: prior to this PR, models with sliding window attention were forced to have
        # `cache_implementation="hybrid"` (the static sliding window cache). For these models, we now want to use
        # the dynamic sliding window cache by default, so we UNSET `cache_implementation` if it is a default value.
        # (if we're inside this branch, then it is because we're using default values from the Hub)
        if generation_config.cache_implementation == "hybrid":
            generation_config.cache_implementation = None

    # `torch.export.export` usually raises an exception if it is called
    # with ``strict=True``. deepcopy can only be processed if ``strict=False``.
    generation_config = copy.deepcopy(generation_config)

    if not using_model_generation_config:
        # If `generation_config` is provided:
        # - `use_model_defaults`: let's fallback ALL default values to the model's generation config
        # - otherwise: legacy behavior, let's just make sure we have the tokens defined
        model_base_version = version.parse(version.parse(self.generation_config.transformers_version).base_version)
        fall_back_to_model_defaults = use_model_defaults is True or (
            use_model_defaults is None and model_base_version >= version.parse("4.50.0")
        )
        if fall_back_to_model_defaults:
            modified_values = {}
            library_defaults = GenerationConfig()
            model_defaults = self.generation_config
            # we iterate over the model's generation config: it may hold custom keys, which we'll want to copy
            for key, model_value in model_defaults.__dict__.items():
                is_metadata = key.startswith("_") or key == "transformers_version"
                # Don't set `cache_implementation = 'hybrid'` from the model defaults, see #40135
                is_hybrid_cache = key == "cache_implementation" and model_defaults.cache_implementation == "hybrid"
                if is_metadata or is_hybrid_cache:
                    continue
                library_value = getattr(library_defaults, key, None)
                user_value = getattr(generation_config, key, None)
                # only overwrite values the user left at the library default
                if user_value == library_value and model_value != library_value:
                    modified_values[key] = model_value
                    setattr(generation_config, key, model_value)
            # edge case: we may set `temperature=0.0` and `do_sample=False`, but the model defaults to
            # `do_sample=True`
            if generation_config.temperature == 0.0:
                generation_config.do_sample = False
            if use_model_defaults is None and len(modified_values) > 0:
                logger.warning_once(
                    f"`generation_config` default values have been modified to match model-specific defaults: "
                    f"{modified_values}. If this is not desired, please set these values explicitly."
                )
        else:
            for token_attr in ("bos_token_id", "eos_token_id", "pad_token_id", "decoder_start_token_id"):
                if getattr(generation_config, token_attr) is None:
                    setattr(generation_config, token_attr, getattr(self.generation_config, token_attr))

    # Finally, apply any passed kwargs
    model_kwargs = generation_config.update(**kwargs)

    # And keep in model_kwargs variable output controls
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    if output_attentions:
        model_kwargs.update({"output_attentions": output_attentions})
    if output_hidden_states:
        model_kwargs.update({"output_hidden_states": output_hidden_states})

    return generation_config, model_kwargs
def _get_initial_cache_position(self, seq_length, device, model_kwargs):
    """
    Populate `model_kwargs["cache_position"]` for the pre-fill stage and return `model_kwargs`.

    Args:
        seq_length: Number of positions to create when no embeddings entry is used as the template.
        device: Device on which the positions are created in that same fallback case.
        model_kwargs: Keyword arguments of the generation call. Updated in place.

    Returns:
        The same `model_kwargs` dict. An existing, non-`None` `cache_position` entry is left untouched.
    """
    # Nothing to compute when the caller already provided the positions.
    if model_kwargs.get("cache_position") is not None:
        return model_kwargs

    # Decide which embeddings entry (if any) acts as the template for the positions.
    embeds_key = None
    if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
        embeds_key = "inputs_embeds"
    elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:
        embeds_key = "decoder_inputs_embeds"

    # `ones(...).cumsum(0) - 1` yields 0, 1, 2, ... -- a `torch.compile`-friendly equivalent of `torch.arange`
    # built from a shape.
    if embeds_key is None:
        positions = torch.ones(seq_length, dtype=torch.int64, device=device).cumsum(0) - 1
    else:
        template = model_kwargs[embeds_key][0, :, 0]
        positions = torch.ones_like(template, dtype=torch.int64).cumsum(0) - 1

    # Drop the positions that an existing cache already covers.
    cache = model_kwargs.get("past_key_values")
    if cache is not None:
        if isinstance(cache, tuple):
            # Support for BC tuple cache format: the length is read from dim 2 of the first stored tensor
            tokens_seen = cache[0][0].shape[2]
        elif hasattr(cache, "get_seq_length"):
            tokens_seen = cache.get_seq_length()
        else:
            tokens_seen = 0
        positions = positions[tokens_seen:]

    model_kwargs["cache_position"] = positions
    return model_kwargs
def _get_cache(self, cache_implementation: str, batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
    """
    Sets a cache for `generate`, that will persist across calls (it is stored on `self._cache`).

    A new cache is only initialized if there is no stored cache yet, or if the stored one does not fit this call:
    different offloading setting, different batch size, a smaller maximum length than requested, or (when a
    cross-attention cache is needed) a cross-attention length that differs from the encoder output length.
    Otherwise the stored cache is reset and reused.

    Args:
        cache_implementation: Name of the requested cache implementation. Offloading is enabled when the name
            contains `"offloaded"`.
        batch_size: Batch size the stored cache is compared against.
        max_cache_len: Minimum length the self-attention cache must support.
        model_kwargs: Model keyword arguments. `model_kwargs["encoder_outputs"]` is read when a cross-attention
            cache is required.

    Returns:
        The resulting cache object (also available as `self._cache`).
    """
    needs_cross_cache = self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
    use_offloading = "offloaded" in cache_implementation

    if not hasattr(self, "_cache"):
        rebuild = True
    else:
        # With a cross-attention cache, the self-attention part is the one compared against the request
        self_cache = self._cache.self_attention_cache if needs_cross_cache else self._cache
        rebuild = (
            self_cache.offloading != use_offloading
            or self_cache.max_batch_size != batch_size
            or self_cache.max_cache_len < max_cache_len
        )
        if needs_cross_cache and not rebuild:
            # The cross-attention cache has to match the encoder output length exactly
            rebuild = self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]

    if not rebuild:
        self._cache.reset()
        return self._cache

    def _new_static_cache(cache_len):
        # Both caches share the decoder text config and the offloading setting; only the length differs.
        return StaticCache(
            config=self.config.get_text_config(decoder=True),
            max_cache_len=cache_len,
            offloading=use_offloading,
        )

    self._cache = _new_static_cache(max_cache_len)
    if needs_cross_cache:
        cross_cache = _new_static_cache(model_kwargs["encoder_outputs"][0].shape[1])
        self._cache = EncoderDecoderCache(self._cache, cross_cache)
    return self._cache
@classmethod
def _supports_default_dynamic_cache(cls) -> bool:
    """
    Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.

    Stateful classes (`cls._is_stateful`) and classes whose lowercased name contains one of the markers below
    are excluded. This adds exception for some models like `Mamba` models which use their own caches
    and do not need to initialize the Cache in advance in order to save memory (because no back and forth
    `to_legacy_cache` and `from_legacy_cache` will be performed for mamba-based models).
    """
    if cls._is_stateful:
        return False

    # NOTE: remove xlnet/reformer when the models are deprecated, non-standard model architecture/cache name
    unsupported_markers = ("reformer", "minimax", "xlnet", "lfm2", "lfm2-vl")
    lowered_name = cls.__name__.lower()
    return not any(marker in lowered_name for marker in unsupported_markers)
def _prepare_cache_for_generation(
    self,
    generation_config: GenerationConfig,
    model_kwargs: dict,
    generation_mode: GenerationMode,
    batch_size: int,
    max_cache_length: int,
) -> None:
    """
    Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
    instantiated, writes it to `model_kwargs`, under the name expected by the model.

    Args:
        generation_config: Generation parameterization. `cache_implementation` is reset to `None` in place when
            an assistant model is used together with an explicit cache implementation.
        model_kwargs: Model keyword arguments, updated in place with the prepared cache.
        generation_mode: Selected generation mode. Assisted generation and contrastive search get a dynamic
            cache built without the model config.
        batch_size: Input batch size. For static caches it is multiplied by
            `max(num_beams, num_return_sequences)`.
        max_cache_length: Maximum cache length forwarded to static caches.

    Returns:
        `None`. The result is communicated through `model_kwargs`.

    Raises:
        ValueError: If both a cache object and `cache_implementation` are passed, or if a quantized cache is
            requested for a model that does not support it.
        ImportError: If the backend required by the quantized cache is not installed.
    """
    is_hybrid_cache = any(class_name in self.__class__.__name__.lower() for class_name in ["mamba", "falconh1"])
    cache_name = "past_key_values" if not is_hybrid_cache else "cache_params"

    requires_cross_attention_cache = (
        self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
    )

    # Quick escape route 1: if the user specifies a cache, we only need to:
    # a) check for conflicting `generate` arguments
    # b) convert to the new cache format (if the user passes a legacy cache and model supports it)
    user_defined_cache = model_kwargs.get(cache_name)
    if user_defined_cache is not None:
        if generation_config.cache_implementation is not None:
            raise ValueError(
                f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
                "Cache object) is unsupported. Please use only one of the two."
            )
        if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache():
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `Cache` instead."
            )
            model_kwargs[cache_name] = (
                DynamicCache.from_legacy_cache(user_defined_cache)
                if not requires_cross_attention_cache
                else EncoderDecoderCache.from_legacy_cache(user_defined_cache)
            )
        return

    # Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in
    # `generation_config.validate()`)
    if generation_config.use_cache is False:
        return

    # Quick escape route 3: model that only supports legacy caches or models that supply it in
    # `prepare_inputs_for_generation` (mamba, zamba, ...)
    if not self._supports_default_dynamic_cache():
        if generation_config.cache_implementation is not None:
            logger.warning_once(
                "This model does not support `Cache` instances. `cache_implementation` (set to "
                f"{generation_config.cache_implementation}) will be ignored.",
            )
        return

    # Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation`

    # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
    # which is only supported in dynamic caches atm
    if (
        generation_mode == GenerationMode.ASSISTED_GENERATION
        and generation_config.cache_implementation is not None
    ):
        logger.warning_once(
            "An assistant model is provided, using a dynamic cache instead of a cache of type="
            f"'{generation_config.cache_implementation}'."
        )
        generation_config.cache_implementation = None

    # Assisted decoding and contrastive search require cache rollback, which is incompatible with sliding layers.
    # To handle this, we skip passing the model config to DynamicCache (forcing a full-layer cache).
    # The "dynamic_full" option is a shortcut for generate() users to avoid sliding layers on their own.
    if (
        generation_mode in (GenerationMode.ASSISTED_GENERATION, GenerationMode.CONTRASTIVE_SEARCH)
        or generation_config.cache_implementation == "dynamic_full"
    ):
        dynamic_cache_kwargs = {}
    else:
        dynamic_cache_kwargs = {"config": self.config.get_text_config(decoder=True)}

    if generation_config.cache_implementation is not None:
        if generation_config.cache_implementation in ALL_STATIC_CACHE_IMPLEMENTATIONS:
            if generation_config.cache_implementation in DEPRECATED_STATIC_CACHE_IMPLEMENTATIONS:
                logger.warning_once(
                    f"Using `cache_implementation='{generation_config.cache_implementation}' is deprecated. "
                    f"Please only use one of {STATIC_CACHE_IMPLEMENTATIONS}, and the layer structure will be "
                    "inferred automatically."
                )
            model_kwargs[cache_name] = self._get_cache(
                cache_implementation=generation_config.cache_implementation,
                batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
                max_cache_len=max_cache_length,
                model_kwargs=model_kwargs,
            )
        elif generation_config.cache_implementation == "quantized":
            if self.config.is_encoder_decoder or not self._supports_default_dynamic_cache():
                raise ValueError(
                    "This model does not support the quantized cache. If you want your model to support quantized "
                    "cache, please open an issue and tag @zucchini-nlp."
                )

            # Work on a shallow copy: the insert/pop below must not alter `generation_config.cache_config`,
            # otherwise the caller's dict would lose its `backend` entry and gain a `config` entry.
            cache_config = (
                dict(generation_config.cache_config) if generation_config.cache_config is not None else {}
            )
            # Add the config if it was not provided, as it's a required argument
            if "config" not in cache_config:
                cache_config["config"] = self.config.get_text_config()
            # Pop the backend from the config (defaults to quanto if not defined)
            backend = cache_config.pop("backend", "quanto")

            if backend == "quanto" and not is_optimum_quanto_available():
                raise ImportError(
                    "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto "
                    "backend. Please install it via with `pip install optimum-quanto`"
                )
            elif backend == "HQQ" and not is_hqq_available():
                raise ImportError(
                    "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
                    "Please install it via with `pip install hqq`"
                )
            model_kwargs[cache_name] = QuantizedCache(backend=backend, **cache_config)
        elif generation_config.cache_implementation == "offloaded":
            model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs, offloading=True)
        elif "dynamic" in generation_config.cache_implementation:
            model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs)

    # Use DynamicCache instance by default. This will avoid back and forth from legacy format that
    # keeps copying the cache thus using much more memory
    # TODO (joao): remove this `else` when we remove the last traces of the legacy cache format (v4.58.0, search
    # for `instance(past_key_values, Cache)` as well). In general, if `cache_implementation` is unset, cache
    # initialization should happen inside the model at prefill time.
    else:
        model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs)

    # TODO (joao): this logic is incomplete, e.g. `offloaded` should apply to both caches. Refactor this function
    # to correctly pass parameterization to both caches.
    if requires_cross_attention_cache and not isinstance(model_kwargs[cache_name], EncoderDecoderCache):
        model_kwargs[cache_name] = EncoderDecoderCache(
            model_kwargs[cache_name],  # self-attention cache
            DynamicCache(**dynamic_cache_kwargs),  # cross-attention cache
        )
def _supports_logits_to_keep(self) -> bool:
    """
    Return True if the current model supports the keyword argument `logits_to_keep` in forward()
    to save memory.

    The answer comes from inspecting the signature of `self.forward`, which avoids the need for a dedicated
    model attribute.
    """
    forward_parameters = inspect.signature(self.forward).parameters
    return "logits_to_keep" in forward_parameters
def _prepare_special_tokens(
    self,
    generation_config: GenerationConfig,
    kwargs_has_attention_mask: Optional[bool] = None,
    device: Optional[Union[torch.device, str]] = None,
):
    """
    Prepares the special tokens for generation, storing their tensor versions on the generation config (as
    `_bos_token_tensor`, `_eos_token_tensor`, `_pad_token_tensor` and `_decoder_start_token_tensor`).

    Note that `generation_config` is changed in place and stops being serializable after this method is called.
    That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
    function). However, if called outside `generate`, consider creating a copy of `generation_config` first.

    Args:
        generation_config: Generation configuration holding the special token ids. Modified in place.
        kwargs_has_attention_mask: Whether the caller passed an attention mask. `None` disables the
            attention-mask related warnings.
        device: Device for the token tensors. Defaults to `self.device` when `None`.

    Raises:
        ValueError: For encoder-decoder models when neither `decoder_start_token_id` nor `bos_token_id` is set.
    """
    # The attention-mask warnings only fire when we were explicitly told that no mask was passed
    attention_mask_missing = kwargs_has_attention_mask is not None and not kwargs_has_attention_mask

    # Convert special tokens to tensors
    def _to_tensor(token_id, target_device=None):
        if token_id is None:
            return None
        if target_device is None:
            target_device = self.device
        if isinstance(token_id, torch.Tensor):
            return token_id.to(target_device)
        return torch.tensor(token_id, device=target_device, dtype=torch.long)

    bos_tensor = _to_tensor(generation_config.bos_token_id, target_device=device)
    eos_tensor = _to_tensor(generation_config.eos_token_id, target_device=device)
    pad_tensor = _to_tensor(generation_config.pad_token_id, target_device=device)
    decoder_start_tensor = _to_tensor(generation_config.decoder_start_token_id, target_device=device)

    # for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892)
    if self.config.is_encoder_decoder and decoder_start_tensor is None:
        decoder_start_tensor = bos_tensor

    # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
    if eos_tensor is not None and eos_tensor.ndim == 0:
        eos_tensor = eos_tensor.unsqueeze(0)

    # Set pad token if unset (and there are conditions to do so): fall back to the first eos token
    if pad_tensor is None and eos_tensor is not None:
        if attention_mask_missing:
            logger.warning(
                "The attention mask and the pad token id were not set. As a consequence, you may observe "
                "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
            )
        pad_tensor = eos_tensor[0]
        logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_tensor} for open-end generation.")

    # Sanity checks/warnings
    if self.config.is_encoder_decoder and decoder_start_tensor is None:
        raise ValueError(
            "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
        )
    if eos_tensor is not None and isin_mps_friendly(elements=eos_tensor, test_elements=pad_tensor).any():
        if attention_mask_missing:
            logger.warning_once(
                "The attention mask is not set and cannot be inferred from input because pad token is same as "
                "eos token. As a consequence, you may observe unexpected behavior. Please pass your input's "
                "`attention_mask` to obtain reliable results."
            )
    if eos_tensor is not None and (torch.is_floating_point(eos_tensor) or (eos_tensor < 0).any()):
        logger.warning(
            f"`eos_token_id` should consist of positive integers, but is {eos_tensor}. Your generation "
            "will not stop until the maximum length is reached. Depending on other flags, it may even crash."
        )

    # Update generation config with the updated special tokens tensors
    # NOTE: this must be written into a different attribute name than the one holding the original special tokens
    # (in their non-tensor form), in order to enable end-to-end compilation. See
    # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
    generation_config._bos_token_tensor = bos_tensor
    generation_config._eos_token_tensor = eos_tensor
    generation_config._pad_token_tensor = pad_tensor
    generation_config._decoder_start_token_tensor = decoder_start_tensor
def _valid_auto_compile_criteria(self, model_kwargs: dict[str, Any], generation_config: GenerationConfig) -> bool:
    """
    Determines whether to trigger auto-compilation of the model's forward pass at generation time.

    Compilation requires suitable hardware (a CUDA device, or a `compile_config` that opts into all devices) and
    a compileable `Cache` under `model_kwargs["past_key_values"]`. It is vetoed by `disable_compile`, by a
    quantizer that is not compileable, and by CPU or disk offload in `hf_device_map`. A warning is logged when a
    `compile_config` was set but compilation cannot happen.
    """
    # Override: honor `disable_compile` flag
    if generation_config.disable_compile:
        return False

    compile_config = generation_config.compile_config

    # Base logic
    on_supported_hardware = self.device.type == "cuda" or bool(
        compile_config is not None and compile_config._compile_all_devices
    )
    cache = model_kwargs.get("past_key_values")
    cache_is_compileable = isinstance(cache, Cache) and model_kwargs["past_key_values"].is_compileable
    can_compile = on_supported_hardware and cache_is_compileable

    # Exception 1: Some quantization methods do not support compilation
    quantizer = getattr(self, "hf_quantizer", None)
    if quantizer is not None:
        can_compile = can_compile & self.hf_quantizer.is_compileable

    if hasattr(self, "hf_device_map"):
        devices_in_use = set(self.hf_device_map.values())
        # Exception 2: Don't compile if the model is using CPU offload (as of April 2025, this results in a crash)
        uses_cpu_offload = "cpu" in devices_in_use and len(devices_in_use) > 1
        can_compile = can_compile & (not uses_cpu_offload)
        # Exception 3: Disk offload is not supported for compilation
        uses_disk_offload = "disk" in devices_in_use
        can_compile = can_compile & (not uses_disk_offload)

    # Finally: if the user has manually specified compilation options, but compilation is not possible, let's warn
    # them
    if compile_config is not None and not can_compile:
        logger.warning_once(
            "You have set `compile_config`, but we are unable to meet the criteria for compilation. Compilation "
            "will be skipped."
        )

    return can_compile
def _get_deprecated_gen_repo(
    self,
    generation_mode: GenerationMode,
    trust_remote_code: bool,
    custom_generate: Optional[str] = None,
) -> Optional[str]:
    """
    Returns the Hub repo for a deprecated generation mode, if any.

    Args:
        generation_mode: Generation mode looked up in `GENERATION_MODES_MAPPING`.
        trust_remote_code: Whether the caller allows loading code from the Hub repo.
        custom_generate: Custom generation already requested by the caller. When set, no lookup happens.

    Returns:
        The repo name when the mapping entry for `generation_mode` contains a `/`, otherwise `None`.

    Raises:
        ValueError: If the mode maps to a repo but `trust_remote_code` is falsy.
    """
    # An explicit `custom_generate` takes precedence: nothing to redirect
    if custom_generate is not None:
        return None

    # Entries containing "/" are treated as Hub repos; the other entries are not redirected
    repo = GENERATION_MODES_MAPPING[generation_mode]
    if "/" not in repo:
        return None

    mode_title = generation_mode.name.replace("_", " ").title()
    logger.warning_once(
        f"{mode_title} was moved to a `custom_generate` repo: https://hf.co/{repo}. "
        f"To prevent loss of backward compatibility, add `custom_generate='{repo}'` "
        "to your `generate` call before v4.62.0."
    )
    if not trust_remote_code:
        raise ValueError(
            f"{mode_title} requires `trust_remote_code=True` in your `generate` call, "
            f"since it loads https://hf.co/{repo}."
        )
    return repo
def _extract_generation_mode_kwargs(
    self,
    custom_generate,
    kwargs,
    synced_gpus,
    assistant_model,
    streamer,
) -> dict[str, Any]:
    """
    Extracts and returns the generation mode related keyword arguments from the provided kwargs.

    Args:
        custom_generate: Custom generation argument of `generate`. Only a callable changes the result.
        kwargs: Keyword arguments passed to `generate`. `tokenizer` and `assistant_tokenizer` are always popped
            from it. For a callable `custom_generate`, its extra arguments are popped as well.
        synced_gpus: Explicit value, or `None` to derive it from the DeepSpeed ZeRO-3 / FSDP state and the
            distributed world size.
        assistant_model: Assistant model passed to `generate`, if any.
        streamer: Streamer passed to `generate`, if any.

    Returns:
        A dict with the non-`None` entries among `tokenizer`, `assistant_tokenizer`, `assistant_model`,
        `streamer` and `synced_gpus`. For a callable `custom_generate`, the dict instead contains only the
        entries of `kwargs` that match parameters of the callable which `GenerationMixin._sample` does not have.
    """
    # Pop the tokenizers first, so they never reach the model kwargs
    tokenizer = kwargs.pop("tokenizer", None)
    assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)

    if synced_gpus is None:
        synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1

    candidates = {
        "tokenizer": tokenizer,
        "assistant_tokenizer": assistant_tokenizer,
        "assistant_model": assistant_model,
        "streamer": streamer,
        "synced_gpus": synced_gpus,
    }
    generation_mode_kwargs = {name: value for name, value in candidates.items() if value is not None}

    # Custom_generate callables can have their own set of arguments
    # To extract them, we compare the signature with the standard _sample method
    if isinstance(custom_generate, Callable):
        standard_parameters = inspect.signature(GenerationMixin._sample).parameters.keys()
        custom_parameters = inspect.signature(custom_generate).parameters.keys()
        extra_parameters = custom_parameters - standard_parameters
        # NOTE(review): this replaces the dict built above, so the entries collected there (including the
        # already-popped `tokenizer`) are not forwarded to the callable -- confirm this is intended.
        generation_mode_kwargs = {name: kwargs.pop(name) for name in extra_parameters if name in kwargs}

    return generation_mode_kwargs
- @torch.no_grad()
- def generate(
- self,
- inputs: Optional[torch.Tensor] = None,
- generation_config: Optional[GenerationConfig] = None,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
- synced_gpus: Optional[bool] = None,
- assistant_model: Optional["PreTrainedModel"] = None,
- streamer: Optional["BaseStreamer"] = None,
- negative_prompt_ids: Optional[torch.Tensor] = None,
- negative_prompt_attention_mask: Optional[torch.Tensor] = None,
- use_model_defaults: Optional[bool] = None,
- custom_generate: Optional[Union[str, Callable]] = None,
- **kwargs,
- ) -> Union[GenerateOutput, torch.LongTensor]:
- r"""
- Generates sequences of token ids for models with a language modeling head.
- <Tip warning={true}>
- Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
- model's default generation configuration. You can override any `generation_config` by passing the corresponding
- parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
- For an overview of generation strategies and code examples, check out the [following
- guide](../generation_strategies).
- </Tip>
- Parameters:
- inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
- The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
- method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
- should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
- `input_ids`, `input_values`, `input_features`, or `pixel_values`.
- generation_config ([`~generation.GenerationConfig`], *optional*):
- The generation configuration to be used as base parametrization for the generation call. `**kwargs`
- passed to generate matching the attributes of `generation_config` will override them. If
- `generation_config` is not provided, the default will be used, which has the following loading
- priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
- configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
- default values, whose documentation should be checked to parameterize generation.
- logits_processor (`LogitsProcessorList`, *optional*):
- Custom logits processors that complement the default logits processors built from arguments and
- generation config. If a logit processor is passed that is already created with the arguments or a
- generation config an error is thrown. This feature is intended for advanced users.
- stopping_criteria (`StoppingCriteriaList`, *optional*):
- Custom stopping criteria that complements the default stopping criteria built from arguments and a
- generation config. If a stopping criteria is passed that is already created with the arguments or a
- generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
- sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
- intended for advanced users.
- prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
- If provided, this function constraints the beam search to allowed tokens only at each step. If not
- provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
- `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
- on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
- for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
- Retrieval](https://huggingface.co/papers/2010.00904).
- synced_gpus (`bool`, *optional*):
- Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
- to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
- deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
- assistant_model (`PreTrainedModel`, *optional*):
- An assistant model that can be used to accelerate generation. The assistant model must have the exact
- same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
- is much faster than running generation with the model you're calling generate from. As such, the
- assistant model should be much smaller.
- streamer (`BaseStreamer`, *optional*):
- Streamer object that will be used to stream the generated sequences. Generated tokens are passed
- through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
- negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- The negative prompt needed for some processors such as CFG. The batch size must match the input batch
- size. This is an experimental feature, subject to breaking API changes in future versions.
- negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Attention_mask for `negative_prompt_ids`.
- use_model_defaults (`bool`, *optional*):
- When it is `True`, unset parameters in `generation_config` will be set to the model-specific default
- generation configuration (`model.generation_config`), as opposed to the global defaults
- (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
- `True`.
- custom_generate (`str` or `Callable`, *optional*):
- One of the following:
- - `str` (Hugging Face Hub repository name): runs the custom `generate` function defined at
- `custom_generate/generate.py` in that repository instead of the standard `generate` method. The
- repository fully replaces the generation logic, and the return type may differ.
- - `str` (local repository path): same as above but from a local path, `trust_remote_code` not required.
- - `Callable`: `generate` will perform the usual input preparation steps, then call the provided callable to
- run the decoding loop.
- For more information, see [the docs](../../generation_strategies#custom-generation-methods).
- kwargs (`dict[str, Any]`, *optional*):
- Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
- forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
- specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
- Return:
- [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
- or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
- If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
- [`~utils.ModelOutput`] types are:
- - [`~generation.GenerateDecoderOnlyOutput`],
- - [`~generation.GenerateBeamDecoderOnlyOutput`]
- If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
- [`~utils.ModelOutput`] types are:
- - [`~generation.GenerateEncoderDecoderOutput`],
- - [`~generation.GenerateBeamEncoderDecoderOutput`]
- """
- # 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
- trust_remote_code = kwargs.pop("trust_remote_code", None)
- if custom_generate is not None and isinstance(custom_generate, str):
- # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
- # they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to
- # trigger the custom generation. They can access to methods from `GenerationMixin` through `model`.
- global_keys_to_exclude = {
- "self",
- "kwargs",
- "global_keys_to_exclude",
- "trust_remote_code",
- "custom_generate",
- }
- generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude}
- generate_arguments.update(kwargs)
- custom_generate_function = self.load_custom_generate(
- custom_generate, trust_remote_code=trust_remote_code, **kwargs
- )
- return custom_generate_function(model=self, **generate_arguments)
- # 1. Handle kwargs, `generation_config`, validate them and obtain generation mode
- generation_mode_kwargs = self._extract_generation_mode_kwargs(
- custom_generate,
- kwargs,
- synced_gpus,
- assistant_model,
- streamer,
- )
- generation_config, model_kwargs = self._prepare_generation_config(
- generation_config, use_model_defaults, **kwargs
- )
- generation_mode = generation_config.get_generation_mode(assistant_model)
- if isinstance(custom_generate, Callable):
- decoding_method = custom_generate
- else:
- # type() required to access the unbound class-level method
- decoding_method = getattr(type(self), GENERATION_MODES_MAPPING[generation_mode])
- self._validate_model_kwargs(model_kwargs.copy())
- self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)
- # Deprecation-related step: set Hub repo for deprecated strategies.
- # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode.
- # It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps.
- # TODO joao, manuel: remove this in v4.62.0
- if deprecated_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate):
- return GenerationMixin.generate(
- self,
- inputs=inputs,
- generation_config=generation_config,
- logits_processor=logits_processor,
- stopping_criteria=stopping_criteria,
- prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
- assistant_model=assistant_model,
- negative_prompt_ids=negative_prompt_ids,
- negative_prompt_attention_mask=negative_prompt_attention_mask,
- use_model_defaults=use_model_defaults,
- custom_generate=deprecated_mode_repo,
- trust_remote_code=trust_remote_code,
- **generation_mode_kwargs,
- **kwargs,
- )
- # 2. Set generation parameters if not already defined
- logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
- stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
- accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
- requires_attention_mask = "encoder_outputs" not in model_kwargs
- kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
- # 3. Define model inputs
- inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
- inputs, generation_config.bos_token_id, model_kwargs
- )
- # Some generation modes (e.g. assisted) need `inputs_tensor` to rerun encoder.forward()
- if "inputs_tensor" in inspect.signature(decoding_method).parameters.keys():
- generation_mode_kwargs["inputs_tensor"] = inputs_tensor
- batch_size = inputs_tensor.shape[0]
- device = inputs_tensor.device
- self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
- # decoder-only models must use left-padding for batched generation.
- if not self.config.is_encoder_decoder:
- # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
- # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
- if (
- generation_config._pad_token_tensor is not None
- and batch_size > 1
- and len(inputs_tensor.shape) == 2
- and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
- ):
- logger.warning(
- "A decoder-only architecture is being used, but right-padding was detected! For correct "
- "generation results, please set `padding_side='left'` when initializing the tokenizer."
- )
- # 4. Define other model kwargs
- # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
- # generating the first new token or not, and we only want to use the embeddings for the first new token)
- if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
- generation_config.use_cache = True
- if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
- model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
- inputs_tensor, generation_config, model_kwargs
- )
- elif kwargs_has_attention_mask:
- # TODO (joao): generalize this check with other types of inputs
- if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
- raise ValueError("`attention_mask` passed to `generate` must be 2D.")
- if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
- # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
- model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
- inputs_tensor, model_kwargs, model_input_name, generation_config
- )
- # 5. Prepare `input_ids` which will be used for auto-regressive generation
- if self.config.is_encoder_decoder:
- input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
- batch_size=batch_size,
- model_input_name=model_input_name,
- model_kwargs=model_kwargs,
- decoder_start_token_id=generation_config._decoder_start_token_tensor,
- device=inputs_tensor.device,
- )
- else:
- input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
- # Expand inputs depending on the generation mode
- input_ids, model_kwargs = self._expand_inputs_for_generation(
- input_ids=input_ids,
- expand_size=max(generation_config.num_beams, generation_config.num_return_sequences),
- is_encoder_decoder=self.config.is_encoder_decoder,
- **model_kwargs,
- )
- if generation_config.token_healing:
- input_ids = self.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer"))
- if streamer is not None:
- streamer.put(input_ids.cpu())
- # 6. Prepare `max_length` depending on other stopping criteria.
- input_ids_length = input_ids.shape[1]
- has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
- has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
- generation_config = self._prepare_generated_length(
- generation_config=generation_config,
- has_default_max_length=has_default_max_length,
- has_default_min_length=has_default_min_length,
- model_input_name=model_input_name,
- inputs_tensor=inputs_tensor,
- input_ids_length=input_ids_length,
- )
- # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
- # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
- # dynamically overrides this value as it can need more than the last token logits
- if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
- model_kwargs["logits_to_keep"] = 1
- self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
- # 7. Prepare the cache.
- # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
- # - different models have a different cache name expected by the model (default = "past_key_values")
- # - `max_length`, prepared above, is used to determine the maximum cache length
- max_cache_length = generation_config.max_length - 1
- if (
- inputs_tensor.shape[1] != input_ids_length
- and model_input_name == "inputs_embeds"
- and not self.config.is_encoder_decoder
- ):
- max_cache_length += inputs_tensor.shape[1]
- self._prepare_cache_for_generation(
- generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
- )
- if self.device.type != input_ids.device.type:
- warnings.warn(
- "You are calling .generate() with the `input_ids` being on a device type different"
- f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
- f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
- " Please make sure that you have put `input_ids` to the"
- f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
- " running `.generate()`.",
- UserWarning,
- )
- # 8. prepare logits processors and stopping criteria
- prepared_logits_processor = self._get_logits_processor(
- generation_config=generation_config,
- input_ids_seq_length=input_ids_length,
- encoder_input_ids=inputs_tensor,
- prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
- logits_processor=logits_processor,
- device=inputs_tensor.device,
- model_kwargs=model_kwargs,
- negative_prompt_ids=negative_prompt_ids,
- negative_prompt_attention_mask=negative_prompt_attention_mask,
- )
- prepared_stopping_criteria = self._get_stopping_criteria(
- generation_config=generation_config,
- stopping_criteria=stopping_criteria,
- tokenizer=generation_mode_kwargs.get("tokenizer"),
- )
- # Set model_kwargs `use_cache` so we can use it later in forward runs
- model_kwargs["use_cache"] = generation_config.use_cache
- # 9. Call generation mode
- result = decoding_method(
- self,
- input_ids,
- logits_processor=prepared_logits_processor,
- stopping_criteria=prepared_stopping_criteria,
- generation_config=generation_config,
- **generation_mode_kwargs,
- **model_kwargs,
- )
- # Convert to legacy cache format if requested
- if (
- generation_config.return_legacy_cache is True
- and hasattr(result, "past_key_values")
- and getattr(result.past_key_values, "to_legacy_cache") is not None
- ):
- result.past_key_values = result.past_key_values.to_legacy_cache()
- return result
- def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
- """
- Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
- fed through `this_peer_finished`. ZeRO stage 3-friendly.
- """
- if synced_gpus:
- # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
- # The following logic allows an early break if all peers finished generating their sequence
- this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device)
- # send 0.0 if we finished, 1.0 otherwise
- dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
- # did all peers finish? the reduced sum will be 0.0 then
- if this_peer_finished_flag.item() == 0.0:
- return False
- elif this_peer_finished:
- return False
- return True
def heal_tokens(
    self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
) -> torch.LongTensor:
    r"""
    Heals the tail token of every prompt in `input_ids`.

    The prompts are decoded, stripped and re-tokenized. For each sequence, the last token is then regenerated with
    a single `generate` step (`max_new_tokens=1`) whose `sequence_bias` favors the vocabulary tokens that extend
    the tail token, with a slightly larger bias on the original tail token.

    Parameters:
        input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation.
        tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids.
    Return:
        `torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension.
    Raises:
        `ValueError`: If `tokenizer` is `None`.
    """
    if tokenizer is None:
        raise ValueError(
            " When generating with token healing, you must pass the model's tokenizer to the `tokenizer` "
            "argument of `generate`."
        )

    bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id
    vocab_trie = ExtensionsTrie(tokenizer.get_vocab())
    generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id)

    # assumption: leading/trailing whitespace is not meaningful, so the prompts are
    # stripped before re-tokenizing to desensitize generation to whitespace artefacts
    prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
    input_ids = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
    ).input_ids.to(input_ids.device)

    # replace bos with pad to not condition healing on it
    input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)

    # everything below indexes the last position, so an empty tensor is returned untouched
    if input_ids.numel() == 0:
        return input_ids

    tail_ids = input_ids[:, -1].tolist()

    # tail tokens are used for a prefix search, thus, whitespaces are replaced with
    # their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
    space_id = tokenizer.convert_tokens_to_ids(" ")
    if space_id is not None:
        space_tok = tokenizer.convert_ids_to_tokens(space_id)[0]
        tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
    else:
        tail_toks = (tokenizer.decode(t) for t in tail_ids)

    for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):
        batch_ids = input_ids[batch_idx]
        if torch.all(batch_ids == pad_token_id).item():
            continue  # skip empty sequences (all pad ids)

        # apply bias for alternatives (extensions) to the tail token; `sequence_bias` keys have to be
        # tuples of ints, so the token strings are converted back to ids with the tokenizer
        seq_bias = {
            (tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok)
        }
        # skip if there are no token alternatives to heal with; `<=` also covers a prefix search that
        # returned nothing at all, which would otherwise fail on the tail-token lookup below
        if len(seq_bias) <= 1:
            continue

        # slightly favor original token to limit aggressive healing e.g. 'http' -> 'https'.
        # `.get` keeps this working when the tail id itself is not among the extensions found.
        tail_key = (tail_id,)
        seq_bias[tail_key] = seq_bias.get(tail_key, 10.0) + 1.0
        generation_config.update(sequence_bias=seq_bias)

        trimmed_ids = batch_ids[:-1]
        # the code below writes to / generates from `trimmed_ids`, so it must not be empty
        if trimmed_ids.numel() == 0:
            continue

        # if the prompt is a single (non-pad) token, regenerate from bos
        if len(batch_ids[batch_ids != pad_token_id]) == 1:
            trimmed_ids[-1] = bos_token_id

        input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config)

    return input_ids
def _sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Token-by-token decoding loop for models with a language modeling head. At every step the next token is drawn
    by **multinomial sampling** when `generation_config.do_sample` is set, and taken as the argmax of the
    processed scores otherwise. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text
    models.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
            an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
        A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.
    """
    # Read the generation flags once
    pad_token_id = generation_config._pad_token_tensor
    do_sample = generation_config.do_sample
    return_dict_in_generate = generation_config.return_dict_in_generate
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    is_encoder_decoder = self.config.is_encoder_decoder

    # Per-step accumulators: an empty tuple when the quantity is requested, `None` otherwise
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # Encoder-decoder models: the encoder attentions / hidden states come from the precomputed encoder outputs
    if return_dict_in_generate and is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # Bookkeeping of which rows of the batch are still generating
    batch_size, cur_len = input_ids.shape[:2]
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

    # Pick the callable used for the decoding steps (optionally a compiled one)
    model_forward = self.__call__
    if self._valid_auto_compile_criteria(model_kwargs, generation_config):
        os.environ["TOKENIZERS_PARALLELISM"] = "0"
        # If we use FA2 and a static cache, we cannot compile with fullgraph
        if self.config._attn_implementation == "flash_attention_2":
            # only raise warning if the user passed an explicit compile-config
            if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
                logger.warning_once(
                    "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
                    "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
                )
                generation_config.compile_config.fullgraph = False
        model_forward = self.get_compiled_call(generation_config.compile_config)

    # With chunked prefill the prompt is consumed here, so the loop below starts directly in decoding mode
    is_prefill = generation_config.prefill_chunk_size is None
    if not is_prefill:
        model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)

    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # The prefill pass goes through the plain module call; later steps use `model_forward`
        run_forward = self if is_prefill else model_forward
        is_prefill = False
        outputs = run_forward(**model_inputs, return_dict=True)

        # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            continue

        # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
        # (the clone itself is always small)
        next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)

        # Run the logits processors over the last-position logits
        next_token_scores = logits_processor(input_ids, next_token_logits)

        # Record the requested per-step outputs (each accumulator is `None` when not requested)
        if scores is not None:
            scores += (next_token_scores,)
        if raw_logits is not None:
            raw_logits += (next_token_logits,)
        if decoder_attentions is not None:
            if is_encoder_decoder:
                decoder_attentions += (outputs.decoder_attentions,)
                cross_attentions += (outputs.cross_attentions,)
            else:
                decoder_attentions += (outputs.attentions,)
        if decoder_hidden_states is not None:
            if is_encoder_decoder:
                decoder_hidden_states += (outputs.decoder_hidden_states,)
            else:
                decoder_hidden_states += (outputs.hidden_states,)

        # Choose the next token: sample from the softmax distribution, or take the argmax
        if do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # Append the new tokens, stream them, and refresh the finished mask
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1

        # This is needed to properly delete outputs.logits which may be very large for first iteration
        # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
        del outputs

    if streamer is not None:
        streamer.end()

    if not return_dict_in_generate:
        return input_ids

    if is_encoder_decoder:
        return GenerateEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            logits=raw_logits,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    return GenerateDecoderOnlyOutput(
        sequences=input_ids,
        scores=scores,
        logits=raw_logits,
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
        past_key_values=model_kwargs.get("past_key_values"),
    )
- @staticmethod
- def _flatten_beam_dim(tensor: torch.Tensor) -> torch.Tensor:
- """[batch_size, num_beams, ...] -> [batch_size * num_beams, ...]"""
- shape = list(tensor.shape)
- return torch.reshape(tensor, [shape[0] * shape[1]] + shape[2:])
- @staticmethod
- def _unflatten_beam_dim(tensor: torch.Tensor, batch_size: int, num_beams: int) -> torch.Tensor:
- """[batch_size * num_beams, ...] -> [batch_size, num_beams, ...]"""
- shape = list(tensor.shape)
- return torch.reshape(tensor, [batch_size, num_beams] + shape[1:])
- @staticmethod
- def _gather_beams(tensor: torch.Tensor, beam_indices: torch.Tensor) -> torch.Tensor:
- """
- Gathers the beam slices indexed by beam_indices into new beam array.
- Args:
- tensor (`torch.Tensor`): A tensor containing data to be gathered. The tensor is a 2D or a 3D tensor
- with the two first dimensions depicting the batch and the beam dimensions.
- beam_indices (`torch.Tensor` of shape `(batch_size, num_beams_to_select)`): The indices of the beams to
- select .
- Returns:
- A tensor with the selected beams
- """
- # `take_along_dim` requires its indices arg to have the same number of dims as `input`
- while len(beam_indices.shape) < len(tensor.shape):
- beam_indices = beam_indices.unsqueeze(-1)
- gathered_tensor = torch.take_along_dim(input=tensor, indices=beam_indices, dim=1)
- return gathered_tensor
- @staticmethod
- def _check_early_stop_heuristic(
- is_early_stop_heuristic_unsatisfied: torch.Tensor,
- running_beam_scores: torch.Tensor,
- beam_scores: torch.Tensor,
- is_sent_finished: torch.Tensor,
- cur_len: int,
- max_length: int,
- decoder_prompt_len: int,
- early_stopping: Union[bool, str],
- length_penalty: float,
- ):
- """
- Determine whether early stopping is possible by checking if the best possible score of running beams
- could still improve upon the finished ones.
- Mechanism:
- - Without a length penalty, beam scores typically decrease as more tokens are generated.
- So, if the *best possible* score from any running beam is already worse than the *worst* finished beam,
- we can safely stop early.
- - With a length penalty, scores may increase with longer sequences. In this case, we use heuristics
- to estimate the best possible score — though this estimate may not always be correct — and stop
- if no further improvement seems likely.
- We apply different heuristics depending on the value of `early_stopping`:
- 1. `early_stopping == False`:
- -> Use a heuristic that assumes the best score comes from the current length minus the decoder prompt length.
- -> See detailed discussion: https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
- 2. `early_stopping == "never"`:
- -> Estimate the best score using either `max_length` or `cur_len`, depending on the sign of `length_penalty`.
- -> A positive length penalty favors longer sequences, so we use `max_length` in that case.
- NOTE: the canonical beam search implementation can be replicated with `early_stopping="never"` and
- `length_penalty=0.0`, which are NOT the default flags. The default behavior was empirically found to produce
- better sequences (prior to 2022), and changing it is BC breaking.
- """
- if early_stopping == "never" and length_penalty > 0.0:
- best_hypothetical_length = max_length - decoder_prompt_len
- else:
- best_hypothetical_length = cur_len - decoder_prompt_len
- best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
- worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
- return is_early_stop_heuristic_unsatisfied & torch.any(
- best_possible_running_score > worst_finished_score, dim=-1, keepdim=True
- )
- @staticmethod
- def _beam_search_has_unfinished_sequences(
- is_early_stop_heuristic_unsatisfied: torch.Tensor,
- is_sent_finished: torch.Tensor,
- next_token_hits_stopping_criteria: torch.Tensor,
- early_stopping: Union[bool, str],
- ):
- """
- Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
- """
- # a. Can the open beams improve the top completed scores?
- improvement_possible = torch.any(is_early_stop_heuristic_unsatisfied)
- # b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is
- # enabled, where we want to finish as soon as all beams have a completed sequence.
- exists_open_beam = ~(torch.all(is_sent_finished) & (early_stopping is True))
- # c. Have we hit a stopping criteria with all running sequences and have no way to continue? e.g. we have
- # reached `max_length``
- valid_continuations = ~torch.all(next_token_hits_stopping_criteria)
- return improvement_possible & exists_open_beam & valid_continuations
- def _get_top_k_continuations(
- self,
- accumulated_log_probs: torch.Tensor,
- running_sequences: torch.Tensor,
- running_beam_indices: torch.Tensor,
- cur_len: int,
- decoder_prompt_len: int,
- do_sample: bool,
- beams_to_keep: int,
- num_beams: int,
- vocab_size: int,
- batch_size: int,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """
- Get top-K continuations given the accumulated log probs on the next token.
- A few notes to understand what's going on:
- 1. Each item in batch has `num_beams` * `vocab_size` candidate continuations. For each item, get the
- top K [K = (number of EOS tokens + 1) * `num_beams`] candidates with the highest accumulated
- log-probabilities, or sample them without replacement using the accumulated scores
- 2. We gather the top K (as opposed to `num_beams`, or any number lower than K) here so that we have at
- least `num_beams` sequences remaining to continue the live beam search.
- 3. Note that other stopping criteria might result in impossible to continue beams, i.e. all continuations
- selected in this step hit the stopping criteria.
- """
- # TODO (joao): This function should take an optional beam scorer function, to manipulate the scores after
- # token selection. The function should be an argument exposed, so that custom scoring functions can be
- # defined.
- # Gather the top K scores from _all_ beams.
- if do_sample:
- topk_indices = torch.multinomial(
- nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep
- )
- topk_log_probs = torch.gather(input=accumulated_log_probs, dim=1, index=topk_indices)
- else:
- topk_log_probs, topk_indices = torch.topk(accumulated_log_probs, k=beams_to_keep)
- # Gather K top beams, recover the beam index by floor division and token id by modulo division
- topk_current_beam_indices = topk_indices // vocab_size
- topk_running_beam_indices = self._gather_beams(running_beam_indices, topk_current_beam_indices)
- topk_running_sequences = self._gather_beams(running_sequences, topk_current_beam_indices)
- topk_ids = topk_indices % vocab_size
- # Update sequences for the K top-k new sequences.
- topk_running_sequences[:, :, cur_len] = topk_ids
- # we want to store the beam indices with batch information -> real beam index = beam index % num beams
- batch_offset = torch.arange(batch_size, device=topk_ids.device).view(-1, 1) * num_beams
- batch_modified_indices = topk_current_beam_indices + batch_offset
- topk_running_beam_indices[:, :, cur_len - decoder_prompt_len] = batch_modified_indices
- return topk_log_probs, topk_running_sequences, topk_running_beam_indices
- def _get_running_beams_for_next_iteration(
- self,
- topk_log_probs: torch.Tensor,
- topk_running_sequences: torch.Tensor,
- topk_running_beam_indices: torch.Tensor,
- next_token_hits_stopping_criteria: torch.Tensor,
- num_beams: int,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """
- Given the top-K continuations, their scores, and whether they hit a stopping criteria, select the
- best non-finished beams to continue beam search in the next iteration.
- """
- # To prevent these just finished sequences from being used in subsequent iterations, set their log probs
- # to a very large negative value
- topk_running_log_probs = topk_log_probs + next_token_hits_stopping_criteria.to(torch.float32) * -1.0e9
- next_topk_indices = torch.topk(topk_running_log_probs, k=num_beams)[1]
- running_sequences = self._gather_beams(topk_running_sequences, next_topk_indices)
- running_beam_scores = self._gather_beams(topk_running_log_probs, next_topk_indices)
- running_beam_indices = self._gather_beams(topk_running_beam_indices, next_topk_indices)
- return running_sequences, running_beam_scores, running_beam_indices
- def _update_finished_beams(
- self,
- sequences: torch.Tensor,
- topk_running_sequences: torch.Tensor,
- beam_scores: torch.Tensor,
- topk_log_probs: torch.Tensor,
- beam_indices: torch.Tensor,
- topk_running_beam_indices: torch.Tensor,
- is_early_stop_heuristic_unsatisfied: torch.Tensor,
- is_sent_finished: torch.Tensor,
- next_token_hits_stopping_criteria: torch.Tensor,
- top_num_beam_mask: torch.Tensor,
- num_beams: int,
- cur_len: int,
- decoder_prompt_len: int,
- length_penalty: float,
- early_stopping: Union[bool, str],
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- """
- Updates the finished beams if (and only if) there are new completed sequences that have a higher score than
- the current finished sequences.
- """
- # Only the top `num_beam` sequences can be considered for the final returned sequences. Remember: the
- # remaining sequences only exist as a backup to ensure that we have at least `num_beams` sequences to
- # continue.
- did_top_num_beams_just_finished = next_token_hits_stopping_criteria & top_num_beam_mask[None, :]
- # Further process topk logits for the finished beams
- # - add length penalty
- topk_log_probs = topk_log_probs / ((cur_len + 1 - decoder_prompt_len) ** length_penalty)
- # - make sure no scores can be added anymore if beam is full and early stopping is on
- beams_in_batch_are_full = torch.all(is_sent_finished, axis=-1, keepdims=True) & (early_stopping is True)
- topk_log_probs += beams_in_batch_are_full.to(torch.float32) * -1.0e9
- # - make sure no scores can be added anymore if improvement is not possible
- topk_log_probs += (~is_early_stop_heuristic_unsatisfied).to(torch.float32) * -1.0e9
- # - make sure still running sequences cannot be chosen as finalized beam
- topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9
- # Get finalized `num_beam` sequences for the next generation step -- combine the previous finalized
- # data with the new finalized sequences (if any, non-finalized sequences have a very large negative score
- # in this step), and keep the best `num_beams` sequences.
- merged_sequences = torch.cat((sequences, topk_running_sequences), dim=1)
- merged_scores = torch.cat((beam_scores, topk_log_probs), dim=1)
- merged_beam_indices = torch.cat((beam_indices, topk_running_beam_indices), dim=1)
- merged_is_sent_finished = torch.cat((is_sent_finished, did_top_num_beams_just_finished), dim=1)
- topk_merged_indices = torch.topk(merged_scores, k=num_beams)[1]
- sequences = self._gather_beams(merged_sequences, topk_merged_indices)
- beam_scores = self._gather_beams(merged_scores, topk_merged_indices)
- beam_indices = self._gather_beams(merged_beam_indices, topk_merged_indices)
- is_sent_finished = self._gather_beams(merged_is_sent_finished, topk_merged_indices)
- return sequences, beam_scores, beam_indices, is_sent_finished
- # end of auxiliary functions for beam search
- def _beam_search(
- self,
- input_ids: torch.LongTensor,
- logits_processor: LogitsProcessorList,
- stopping_criteria: StoppingCriteriaList,
- generation_config: GenerationConfig,
- synced_gpus: bool = False,
- **model_kwargs,
- ) -> Union[GenerateBeamOutput, torch.LongTensor]:
- r"""
- Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
- can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
- If it's the first time you're diving into Beam Search, we recommend you read the following blog post:
- https://huggingface.co/blog/how-to-generate (especially the beam search section).
- You can recompute the sequence scores from the individual scores using the `compute_transition_scores` function
- (https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores)
- Parameters:
- input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
- The sequence used as a prompt for the generation.
- logits_processor (`LogitsProcessorList`):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
- used to modify the prediction scores of the language modeling head applied at each generation step.
- stopping_criteria (`StoppingCriteriaList`:
- An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
- used to tell if the generation loop should stop.
- generation_config ([`~generation.GenerationConfig`]):
- The generation configuration to be used as parametrization of the decoding method.
- synced_gpus (`bool`):
- Whether to continue running the while loop until max_length (needed to avoid deadlocking with
- `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
- model_kwargs:
- Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
- an encoder-decoder model the kwargs should include `encoder_outputs`.
- Return:
- [`generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
- `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
- [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
- `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
- `model.config.is_encoder_decoder=True`.
- """
- # 1. init beam_search values
- pad_token_id = generation_config._pad_token_tensor
- eos_token_id = generation_config._eos_token_tensor
- output_attentions = generation_config.output_attentions
- output_hidden_states = generation_config.output_hidden_states
- output_scores = generation_config.output_scores
- output_logits = generation_config.output_logits
- return_dict_in_generate = generation_config.return_dict_in_generate
- do_sample = generation_config.do_sample
- early_stopping = generation_config.early_stopping
- length_penalty = generation_config.length_penalty
- max_length = generation_config.max_length
- num_beams = generation_config.num_beams
- num_return_sequences = generation_config.num_return_sequences
- batch_size_unflattened, cur_len = input_ids.shape[:2]
- batch_size = batch_size_unflattened // num_beams
- # TODO (joao): standardize special cases
- if self.__class__.__name__ == "MoshiDepthDecoder":
- vocab_size = self.config.audio_vocab_size
- elif self.__class__.__name__ == "ImageGPTForCausalImageModeling":
- vocab_size = self.get_output_embeddings().out_features
- elif self.__class__.__name__ == "BarkSemanticModel":
- vocab_size = self.config.output_vocab_size
- else:
- vocab_size = self.config.get_text_config().vocab_size
- decoder_prompt_len = cur_len
- this_peer_finished = False
- # At each beam search step, we want to keep top K [K = (number of EOS tokens + 1) * `num_beams`] candidates
- # with the highest log-probabilities, or sample K continuations without replacement. We gather the top K
- # (as opposed to `num_beams`, or any number lower than K) so that we have at least `num_beams` sequences
- # non-finished to continue the live beam search, in case the top `num_beams` all select an EOS token.
- n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
- beams_to_keep = max(2, 1 + n_eos_tokens) * num_beams
- top_num_beam_mask = torch.cat(
- (torch.ones((num_beams), dtype=torch.bool), torch.zeros((beams_to_keep - num_beams), dtype=torch.bool)),
- dim=0,
- ).to(input_ids.device)
- model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
- # (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there
- # are newer low-memory alternatives like the offloaded cache)
- sequential = generation_config.low_memory
- if sequential:
- raise ValueError(
- "`low_memory=True` is not supported after the beam search refactor. Please check the discussion in "
- "#35802 *after the PR got merged*, and add a comment there if your questions are not yet answered."
- )
- # 2. init output tuples
- all_scores = () if (return_dict_in_generate and output_scores) else None
- raw_logits = () if (return_dict_in_generate and output_logits) else None
- beam_indices = () if (return_dict_in_generate and output_logits) else None
- decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
- cross_attentions = () if (return_dict_in_generate and output_attentions) else None
- decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
- # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
- if return_dict_in_generate and self.config.is_encoder_decoder:
- encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
- encoder_hidden_states = (
- model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
- )
- # 3. init running tensors and static-shaped placeholders
- # per batch, beam-item holding current token in loop and completed sequences
- output_fill_value = pad_token_id or eos_token_id[0] if eos_token_id is not None else -1
- running_sequences = torch.full(
- (batch_size, num_beams, max_length),
- fill_value=output_fill_value,
- dtype=torch.int64,
- device=input_ids.device,
- )
- running_sequences[:, :, :cur_len] = self._unflatten_beam_dim(input_ids, batch_size, num_beams)
- sequences = running_sequences.detach().clone()
- # per batch, beam-item score, logprobs
- # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
- # of the first beam are considered to avoid sampling the exact same tokens across all beams.
- running_beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
- running_beam_scores[:, 1:] = -1e9
- beam_scores = torch.full((batch_size, num_beams), fill_value=-1e9, dtype=torch.float, device=input_ids.device)
- # per batch, beam-item state bit indicating if sentence has finished.
- is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device)
- # per batch state bit indicating if there is a possibility to improve the best finished sentence.
- is_early_stop_heuristic_unsatisfied = torch.ones((batch_size, 1), dtype=torch.bool, device=input_ids.device)
- # per batch, beam-item state bit indicating if there are valid continuations.
- next_token_hits_stopping_criteria = torch.zeros(
- (batch_size, num_beams), dtype=torch.bool, device=input_ids.device
- )
- # per batch selected beam indices
- running_beam_indices = torch.full(
- (batch_size, num_beams, max_length - cur_len), fill_value=-1, dtype=torch.int32, device=input_ids.device
- )
- beam_indices = running_beam_indices.detach().clone()
- # 4. run the generation loop
- while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
- # a. Forward current tokens, obtain the logits
- flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
- model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)
- model_outputs = self(**model_inputs, return_dict=True)
- # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
- model_kwargs = self._update_model_kwargs_for_generation(
- model_outputs,
- model_kwargs,
- is_encoder_decoder=self.config.is_encoder_decoder,
- )
- if synced_gpus and this_peer_finished:
- continue
- # Copy is needed to avoid keeping a hanging ref
- logits = model_outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
- # b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
- # `temperature`, ...), and add new logprobs to existing running logprobs scores.
- log_probs = nn.functional.log_softmax(logits, dim=-1)
- log_probs = logits_processor(flat_running_sequences, log_probs)
- # Store logits, attentions and hidden_states when required
- if return_dict_in_generate:
- if output_logits:
- raw_logits += (logits.clone(),)
- if return_dict_in_generate and output_scores:
- all_scores += (log_probs.clone(),)
- if output_attentions:
- decoder_attentions += (
- (model_outputs.decoder_attentions,)
- if self.config.is_encoder_decoder
- else (model_outputs.attentions,)
- )
- if self.config.is_encoder_decoder:
- cross_attentions += (model_outputs.cross_attentions,)
- if output_hidden_states:
- decoder_hidden_states += (
- (model_outputs.decoder_hidden_states,)
- if self.config.is_encoder_decoder
- else (model_outputs.hidden_states,)
- )
- # This is needed to properly delete logits which may be very large for first iteration
- # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
- del model_outputs
- log_probs = self._unflatten_beam_dim(log_probs, batch_size, num_beams)
- log_probs = log_probs + running_beam_scores[:, :, None]
- log_probs = torch.reshape(log_probs, (batch_size, num_beams * vocab_size))
- # c. Retrieve top-K continuations, i.e. select the next token (greedy or sampling) and then keep the best
- # continuations among all beams based on the accumulated scores.
- topk_log_probs, topk_running_sequences, topk_running_beam_indices = self._get_top_k_continuations(
- accumulated_log_probs=log_probs,
- running_sequences=running_sequences,
- running_beam_indices=running_beam_indices,
- cur_len=cur_len,
- decoder_prompt_len=decoder_prompt_len,
- do_sample=do_sample,
- beams_to_keep=beams_to_keep,
- num_beams=num_beams,
- vocab_size=vocab_size,
- batch_size=batch_size,
- )
- # d. Check which running sequences have finished
- next_token_hits_stopping_criteria = stopping_criteria(
- self._flatten_beam_dim(topk_running_sequences[:, :, : cur_len + 1]), # remove unfilled token indexes
- all_scores,
- )
- next_token_hits_stopping_criteria = self._unflatten_beam_dim(
- next_token_hits_stopping_criteria, batch_size, beams_to_keep
- )
- # e. Get the non-finished running `num_beams` sequences for the next generation step
- running_sequences, running_beam_scores, running_beam_indices = self._get_running_beams_for_next_iteration(
- topk_log_probs=topk_log_probs,
- topk_running_sequences=topk_running_sequences,
- topk_running_beam_indices=topk_running_beam_indices,
- next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
- num_beams=num_beams,
- )
- # f. Update the completed beams if a new high score in a finished sequence is found
- sequences, beam_scores, beam_indices, is_sent_finished = self._update_finished_beams(
- sequences=sequences,
- topk_running_sequences=topk_running_sequences,
- beam_scores=beam_scores,
- topk_log_probs=topk_log_probs,
- beam_indices=beam_indices,
- topk_running_beam_indices=topk_running_beam_indices,
- is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
- is_sent_finished=is_sent_finished,
- next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
- top_num_beam_mask=top_num_beam_mask,
- num_beams=num_beams,
- cur_len=cur_len,
- decoder_prompt_len=decoder_prompt_len,
- length_penalty=length_penalty,
- early_stopping=early_stopping,
- )
- # g. Prepare remaining data for the next iteration, including computing the stopping condition for
- # beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`)
- # pluck the cache from the beam indices that will be used in the next iteration
- # NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
- if model_kwargs.get("past_key_values", None) is not None:
- beam_idx = self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len])
- if hasattr(self, "_reorder_cache"):
- model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
- else:
- model_kwargs["past_key_values"].reorder_cache(beam_idx)
- cur_len = cur_len + 1
- is_early_stop_heuristic_unsatisfied = self._check_early_stop_heuristic(
- is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
- running_beam_scores=running_beam_scores,
- beam_scores=beam_scores,
- is_sent_finished=is_sent_finished,
- cur_len=cur_len,
- max_length=max_length,
- decoder_prompt_len=decoder_prompt_len,
- early_stopping=early_stopping,
- length_penalty=length_penalty,
- )
- this_peer_finished = not self._beam_search_has_unfinished_sequences(
- is_early_stop_heuristic_unsatisfied,
- is_sent_finished,
- next_token_hits_stopping_criteria,
- early_stopping,
- )
- # 5. prepare outputs
- # Take best beams for each batch (the score is sorted in descending order)
- sequences = self._flatten_beam_dim(sequences[:, :num_return_sequences, :])
- beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
- beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])
- # Crop the static-shaped tensors to the actual size.
- # `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each
- # step. We can use it to detect the generated length, which may be != `cur_len` (e.g. selected beam is from a
- # previous decoding iteration)
- max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
- output_length = decoder_prompt_len + max_generated_length
- sequences = sequences[:, :output_length]
- beam_indices = beam_indices[:, :max_generated_length]
- if return_dict_in_generate:
- if not output_scores:
- beam_scores = None
- if self.config.is_encoder_decoder:
- return GenerateBeamEncoderDecoderOutput(
- sequences=sequences,
- sequences_scores=beam_scores,
- scores=all_scores,
- logits=raw_logits,
- beam_indices=beam_indices,
- encoder_attentions=encoder_attentions,
- encoder_hidden_states=encoder_hidden_states,
- decoder_attentions=decoder_attentions,
- cross_attentions=cross_attentions,
- decoder_hidden_states=decoder_hidden_states,
- past_key_values=model_kwargs.get("past_key_values"),
- )
- else:
- return GenerateBeamDecoderOnlyOutput(
- sequences=sequences,
- sequences_scores=beam_scores,
- scores=all_scores,
- logits=raw_logits,
- beam_indices=beam_indices,
- attentions=decoder_attentions,
- hidden_states=decoder_hidden_states,
- past_key_values=model_kwargs.get("past_key_values"),
- )
- else:
- return sequences
- def _assisted_decoding(
- self,
- input_ids: torch.LongTensor,
- logits_processor: LogitsProcessorList,
- stopping_criteria: StoppingCriteriaList,
- generation_config: GenerationConfig,
- synced_gpus: bool = False,
- streamer: Optional["BaseStreamer"] = None,
- inputs_tensor: Optional[torch.FloatTensor] = None,
- assistant_model: Optional["PreTrainedModel"] = None,
- assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
- tokenizer: Optional["PreTrainedTokenizerBase"] = None,
- **model_kwargs,
- ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
- r"""
- Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
- **sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a
- candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text
- models.
- Parameters:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- The sequence used as a prompt for the generation.
- logits_processor (`LogitsProcessorList`):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
- used to modify the prediction scores of the language modeling head applied at each generation step.
- stopping_criteria (`StoppingCriteriaList`):
- An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
- used to tell if the generation loop should stop.
- generation_config ([`~generation.GenerationConfig`]):
- The generation configuration to be used as parametrization of the decoding method.
- synced_gpus (`bool`):
- Whether to continue running the while loop until max_length (needed to avoid deadlocking with
- `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
- streamer (`BaseStreamer`, *optional*):
- Streamer object that will be used to stream the generated sequences. Generated tokens are passed
- through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
- inputs_tensor (`torch.FloatTensor`, *optional*):
- The input tensor for generation. For decoder models, usually `input_ids`. For encoder-decoder models,
- the tensor that produced `model_kwargs["encoder_outputs"]`.
- assistant_model (`PreTrainedModel`, *optional*):
- The model used to assist the generation process. If not provided, the main model will be used.
- assistant_tokenizer (`PreTrainedTokenizerBase`, *optional*):
- The tokenizer used for the assistant model. If not provided, the token space is assumed to be the same.
- tokenizer (`PreTrainedTokenizerBase`, *optional*):
- The tokenizer used for the main model. If not provided, the token space is assumed to be the same.
- model_kwargs:
- Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
- If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
- Return:
- [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
- `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
- [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
- `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
- `model.config.is_encoder_decoder=True`.
- """
- # The cache must be dynamic for assisted generation, and the check must happen AFTER preparing cache
- if not model_kwargs["use_cache"]:
- raise ValueError("assisted generate requires `use_cache=True`")
- if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] or (
- "past_key_values" in model_kwargs
- and hasattr(model_kwargs["past_key_values"], "layers")
- and any(getattr(l, "is_compileable", False) for l in model_kwargs["past_key_values"].layers)
- ):
- raise ValueError("assisted generate is not supported with Static cache classes`")
- # Get the candidate generator, given the parameterization
- candidate_generator = self._get_candidate_generator(
- generation_config=generation_config,
- input_ids=input_ids,
- inputs_tensor=inputs_tensor,
- assistant_model=assistant_model,
- logits_processor=logits_processor,
- target_tokenizer=tokenizer,
- assistant_tokenizer=assistant_tokenizer,
- model_kwargs=model_kwargs,
- )
- # init values
- do_sample = generation_config.do_sample
- output_attentions = generation_config.output_attentions
- output_hidden_states = generation_config.output_hidden_states
- output_scores = generation_config.output_scores
- output_logits = generation_config.output_logits
- return_dict_in_generate = generation_config.return_dict_in_generate
- # init attention / hidden states / scores tuples
- scores = () if (return_dict_in_generate and output_scores) else None
- raw_logits = () if (return_dict_in_generate and output_logits) else None
- decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
- cross_attentions = () if (return_dict_in_generate and output_attentions) else None
- decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
- # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
- if return_dict_in_generate and self.config.is_encoder_decoder:
- encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
- encoder_hidden_states = (
- model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
- )
- # keep track of which sequences are already finished
- batch_size, cur_len = input_ids.shape[:2]
- if batch_size > 1:
- raise ValueError("assisted generate is only supported for batch_size = 1")
- unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
- model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
- this_peer_finished = False
- is_first_iteration = True # to preserve the same API in the output as other generation methods
- while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
- cur_len = input_ids.shape[1]
- # 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
- candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
- candidate_input_ids = candidate_input_ids.to(self.device)
- if candidate_logits is not None:
- candidate_logits = candidate_logits.to(self.device)
- candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
- is_done_candidate = stopping_criteria(candidate_input_ids, None)
- # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
- # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
- # we use this forward pass to also pick the subsequent logits in the original model.
- # 2.1. Prepare the model inputs
- candidate_kwargs = copy.copy(model_kwargs)
- candidate_kwargs = _prepare_attention_mask(
- candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
- )
- candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
- if "cache_position" in candidate_kwargs:
- candidate_kwargs["cache_position"] = torch.cat(
- (
- candidate_kwargs["cache_position"],
- torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
- ),
- dim=0,
- )
- model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
- if "logits_to_keep" in model_inputs:
- model_inputs["logits_to_keep"] = candidate_length + 1
- # 2.2. Run a forward pass on the candidate sequence
- outputs = self(**model_inputs)
- # 2.3. Process the new logits
- # .float() is needed to retain precision for later logits manipulations
- new_logits = outputs.logits[:, -candidate_length - 1 :].to(
- dtype=torch.float32, device=input_ids.device
- ) # excludes the input prompt if present
- next_token_logits = new_logits.clone()
- if len(logits_processor) > 0:
- for i in range(candidate_length + 1):
- new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
- # 3. Select the accepted tokens. There are two possible cases:
- # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
- # 👉 Apply algorithm 1 from the speculative decoding paper (https://huggingface.co/papers/2211.17192).
- if do_sample and candidate_logits is not None:
- valid_tokens, n_matches = _speculative_sampling(
- candidate_input_ids,
- candidate_logits,
- candidate_length,
- new_logits,
- is_done_candidate,
- )
- # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
- # original model logits with the candidate tokens. We can keep the candidate tokens until the first
- # mismatch, or until the max length is reached.
- else:
- if do_sample:
- probs = new_logits.softmax(dim=-1)
- selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
- else:
- selected_tokens = new_logits.argmax(dim=-1)
- candidate_new_tokens = candidate_input_ids[:, cur_len:]
- n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
- # Ensure we don't generate beyond max_len or an EOS token
- if is_done_candidate and n_matches == candidate_length:
- n_matches -= 1
- valid_tokens = selected_tokens[:, : n_matches + 1]
- # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
- # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
- # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
- # is no match.
- # 4.1. Get the valid continuation, after the matching tokens
- input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
- if streamer is not None:
- streamer.put(valid_tokens.cpu())
- new_cur_len = input_ids.shape[1]
- # 4.2. Discard past key values relative to unused assistant tokens
- outputs.past_key_values.crop(new_cur_len - 1)
- # 5. Update the candidate generation strategy if needed
- candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)
- # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
- model_kwargs = self._update_model_kwargs_for_generation(
- outputs,
- model_kwargs,
- is_encoder_decoder=self.config.is_encoder_decoder,
- num_new_tokens=n_matches + 1,
- )
- if synced_gpus and this_peer_finished:
- continue
- # Store scores, attentions and hidden_states when required
- # Assistant: modified to append one tuple element per token, as in the other generation methods.
- if return_dict_in_generate:
- newly_added_length = n_matches + 1
- if output_scores:
- scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
- if output_logits:
- raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))
- newly_added_length = new_cur_len if is_first_iteration else newly_added_length
- if output_attentions:
- if self.config.is_encoder_decoder:
- cross_attentions = _split_model_outputs(
- cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
- )
- decoder_attentions = _split_model_outputs(
- decoder_attentions,
- outputs.decoder_attentions,
- cur_len,
- newly_added_length,
- is_decoder_attention=True,
- )
- # some (V)LLMs have hard requirement on SDPA and thus never return attn
- elif outputs.attentions[0] is not None:
- decoder_attentions = _split_model_outputs(
- decoder_attentions,
- outputs.attentions,
- cur_len,
- newly_added_length,
- is_decoder_attention=True,
- )
- if output_hidden_states:
- if self.config.is_encoder_decoder:
- decoder_hidden_states = _split_model_outputs(
- decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
- )
- else:
- decoder_hidden_states = _split_model_outputs(
- decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
- )
- unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
- this_peer_finished = unfinished_sequences.max() == 0
- is_first_iteration = False
- if streamer is not None:
- streamer.end()
- if (
- hasattr(candidate_generator, "assistant_model")
- and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic"
- ):
- candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
- candidate_generator.num_assistant_tokens
- )
- if return_dict_in_generate:
- if self.config.is_encoder_decoder:
- return GenerateEncoderDecoderOutput(
- sequences=input_ids,
- scores=scores,
- logits=raw_logits,
- encoder_attentions=encoder_attentions,
- encoder_hidden_states=encoder_hidden_states,
- decoder_attentions=decoder_attentions,
- cross_attentions=cross_attentions,
- decoder_hidden_states=decoder_hidden_states,
- past_key_values=model_kwargs.get("past_key_values"),
- )
- else:
- return GenerateDecoderOnlyOutput(
- sequences=input_ids,
- scores=scores,
- logits=raw_logits,
- attentions=decoder_attentions,
- hidden_states=decoder_hidden_states,
- past_key_values=model_kwargs.get("past_key_values"),
- )
- else:
- return input_ids
- def _prefill_chunking(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, **model_kwargs):
- # Even if we are not compiling the forward, flex is always compiled when used. With chunk prefill, we may
- # end up needing just a bit more graphs than the default (which is 8). Doing this avoids very cryptic warnings
- torch._dynamo.config.cache_size_limit = 64
- chunk_size = generation_config.prefill_chunk_size
- # Only chunk up the token just before last, so that decoding is completely performed outside this function
- # (here we simply prefill the cache)
- input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)
- if "past_key_values" not in model_kwargs:
- raise ValueError("Cannot use prefill chunking without a cache")
- model_forward = self.forward
- compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
- if compile_forward:
- model_forward = self.get_compiled_call(generation_config.compile_config)
- attention_mask = model_kwargs.pop("attention_mask", None)
- past_length = 0
- for input_chunk in input_chunks:
- current_length = past_length + input_chunk.shape[-1]
- # Prepare inputs
- if attention_mask is not None:
- model_kwargs["attention_mask"] = attention_mask[:, :current_length]
- model_kwargs["cache_position"] = torch.arange(
- past_length, current_length, dtype=torch.long, device=input_chunk.device
- )
- model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0)
- model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs)
- outputs = model_forward(**model_inputs, return_dict=True)
- model_kwargs["past_key_values"] = outputs.past_key_values
- past_length = current_length
- model_kwargs["attention_mask"] = attention_mask
- model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
- _ = model_kwargs.pop("position_ids", None)
- return model_kwargs
- def _speculative_sampling(
- candidate_input_ids,
- candidate_logits,
- candidate_length,
- new_logits,
- is_done_candidate,
- ):
- """
- Applies sampling as in the speculative decoding paper (https://huggingface.co/papers/2211.17192, algorithm 1). Returns
- the selected tokens, as well as the number of candidate matches.
- NOTE: Unless otherwise stated, the variable names match those in the paper.
- """
- new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
- # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
- # selected by the assistant, respectively.
- q = candidate_logits.softmax(dim=-1)
- q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
- p = new_logits.softmax(dim=-1)
- p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
- probability_ratio = p_i / q_i
- # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
- # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
- # (= keep with p = probability_ratio). Keep all the tokens until the first rejection
- r_i = torch.rand_like(probability_ratio)
- is_accepted = r_i <= probability_ratio
- n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1
- # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
- if is_done_candidate and n_matches == candidate_length:
- # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
- # due to acceptance on EOS we fix `n_matches`
- n_matches -= 1
- valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
- else:
- # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
- gamma = candidate_logits.shape[1]
- p_n_plus_1 = p[:, n_matches, :]
- if n_matches < gamma:
- q_n_plus_1 = q[:, n_matches, :]
- p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
- p_prime.div_(p_prime.sum())
- else:
- p_prime = p_n_plus_1
- t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
- # The selected tokens include the matches (if any) plus the next sampled tokens
- if n_matches > 0:
- valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
- else:
- valid_tokens = t
- return valid_tokens, n_matches
- def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
- """
- Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
- where each member corresponds to a single generated token.
- """
- # Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the
- # prompt.
- if len(outputs) == 0:
- new_tuple = ()
- for layer in new_outputs:
- last_dim_size = cur_len if is_decoder_attention else layer.shape[-1]
- new_tuple += (layer[..., :cur_len, :last_dim_size],)
- outputs += (new_tuple,)
- # The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly
- cur_len += 1
- added_len -= cur_len
- for i in range(added_len):
- new_tuple = ()
- for layer in new_outputs:
- last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1]
- new_tuple += (layer[..., i : i + 1, :last_dim_size],)
- outputs += (new_tuple,)
- return outputs
|