utils.py 205 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863
  1. # coding=utf-8
  2. # Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import copy
  17. import inspect
  18. import os
  19. import warnings
  20. from dataclasses import dataclass
  21. from typing import TYPE_CHECKING, Any, Callable, Optional, Union
  22. import torch
  23. import torch.distributed as dist
  24. from packaging import version
  25. from torch import nn
  26. from ..cache_utils import (
  27. Cache,
  28. DynamicCache,
  29. EncoderDecoderCache,
  30. QuantizedCache,
  31. StaticCache,
  32. )
  33. from ..dynamic_module_utils import (
  34. check_python_requirements,
  35. get_cached_module_file,
  36. get_class_in_module,
  37. resolve_trust_remote_code,
  38. )
  39. from ..integrations.deepspeed import is_deepspeed_zero3_enabled
  40. from ..integrations.fsdp import is_fsdp_managed_module
  41. from ..masking_utils import create_masks_for_generate
  42. from ..pytorch_utils import isin_mps_friendly
  43. from ..tokenization_utils import ExtensionsTrie
  44. from ..utils import (
  45. ModelOutput,
  46. TransformersKwargs,
  47. is_accelerate_available,
  48. is_hqq_available,
  49. is_optimum_quanto_available,
  50. is_torchdynamo_exporting,
  51. logging,
  52. )
  53. from .candidate_generator import (
  54. AssistantVocabTranslatorCache,
  55. AssistedCandidateGenerator,
  56. AssistedCandidateGeneratorDifferentTokenizers,
  57. CandidateGenerator,
  58. EarlyExitCandidateGenerator,
  59. PromptLookupCandidateGenerator,
  60. UniversalSpeculativeDecodingGenerator,
  61. _prepare_attention_mask,
  62. _prepare_token_type_ids,
  63. )
  64. from .configuration_utils import (
  65. ALL_STATIC_CACHE_IMPLEMENTATIONS,
  66. DEPRECATED_STATIC_CACHE_IMPLEMENTATIONS,
  67. STATIC_CACHE_IMPLEMENTATIONS,
  68. GenerationConfig,
  69. GenerationMode,
  70. )
  71. from .continuous_batching import ContinuousMixin
  72. from .logits_process import (
  73. EncoderNoRepeatNGramLogitsProcessor,
  74. EncoderRepetitionPenaltyLogitsProcessor,
  75. EpsilonLogitsWarper,
  76. EtaLogitsWarper,
  77. ExponentialDecayLengthPenalty,
  78. ForcedBOSTokenLogitsProcessor,
  79. ForcedEOSTokenLogitsProcessor,
  80. InfNanRemoveLogitsProcessor,
  81. LogitNormalization,
  82. LogitsProcessorList,
  83. MinLengthLogitsProcessor,
  84. MinNewTokensLengthLogitsProcessor,
  85. MinPLogitsWarper,
  86. NoBadWordsLogitsProcessor,
  87. NoRepeatNGramLogitsProcessor,
  88. PrefixConstrainedLogitsProcessor,
  89. RepetitionPenaltyLogitsProcessor,
  90. SequenceBiasLogitsProcessor,
  91. SuppressTokensAtBeginLogitsProcessor,
  92. SuppressTokensLogitsProcessor,
  93. TemperatureLogitsWarper,
  94. TopKLogitsWarper,
  95. TopPLogitsWarper,
  96. TypicalLogitsWarper,
  97. UnbatchedClassifierFreeGuidanceLogitsProcessor,
  98. )
  99. from .stopping_criteria import (
  100. ConfidenceCriteria,
  101. EosTokenCriteria,
  102. MaxLengthCriteria,
  103. MaxTimeCriteria,
  104. StoppingCriteria,
  105. StoppingCriteriaList,
  106. StopStringCriteria,
  107. )
  108. if TYPE_CHECKING:
  109. from ..modeling_utils import PreTrainedModel
  110. from ..tokenization_utils_base import PreTrainedTokenizerBase
  111. from .streamers import BaseStreamer
# Module-level logger, created through the library's own `logging` helper (imported from `..utils` above).
logger = logging.get_logger(__name__)
  113. if is_accelerate_available():
  114. from accelerate.hooks import AlignDevicesHook, add_hook_to_module
  115. # Variable names used to hold the cache at generation time
  116. ALL_CACHE_NAMES = [
  117. "past_key_values", # default
  118. "cache_params", # mamba-based models
  119. "state", # rwkv
  120. "mems", # xlnet
  121. "past_buckets_states", # reformer
  122. ]
  123. GENERATION_MODES_MAPPING = {
  124. GenerationMode.SAMPLE: "_sample",
  125. GenerationMode.GREEDY_SEARCH: "_sample",
  126. GenerationMode.BEAM_SEARCH: "_beam_search",
  127. GenerationMode.BEAM_SAMPLE: "_beam_search",
  128. GenerationMode.ASSISTED_GENERATION: "_assisted_decoding",
  129. # Deprecated methods
  130. GenerationMode.DOLA_GENERATION: "transformers-community/dola",
  131. GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search",
  132. GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search",
  133. GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search",
  134. }
  135. @dataclass
  136. class GenerateDecoderOnlyOutput(ModelOutput):
  137. """
  138. Outputs of decoder-only generation models, when using non-beam methods.
  139. Args:
  140. sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  141. The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
  142. if all batches finished early due to the `eos_token_id`.
  143. scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
  144. Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  145. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  146. each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
  147. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
  148. Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  149. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  150. each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
  151. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
  152. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  153. `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
  154. hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
  155. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  156. `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
  157. past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
  158. Returns the model cache, used to speed up decoding. Different models have a different cache format, check
  159. the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
  160. """
  161. sequences: torch.LongTensor
  162. scores: Optional[tuple[torch.FloatTensor]] = None
  163. logits: Optional[tuple[torch.FloatTensor]] = None
  164. attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
  165. hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
  166. past_key_values: Optional[Cache] = None
  167. @dataclass
  168. class GenerateEncoderDecoderOutput(ModelOutput):
  169. """
  170. Outputs of encoder-decoder generation models, when using non-beam methods.
  171. Args:
  172. sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
  173. The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
  174. if all batches finished early due to the `eos_token_id`.
  175. scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
  176. Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  177. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  178. each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
  179. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
  180. Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  181. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  182. each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
  183. encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
  184. Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
  185. sequence_length, sequence_length)`.
  186. encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
  187. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  188. shape `(batch_size, sequence_length, hidden_size)`.
  189. decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
  190. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  191. `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
  192. cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
  193. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  194. `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
  195. decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
  196. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  197. `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
  198. past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
  199. Returns the model cache, used to speed up decoding. Different models have a different cache format, check
  200. the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
  201. """
  202. sequences: torch.LongTensor
  203. scores: Optional[tuple[torch.FloatTensor]] = None
  204. logits: Optional[tuple[torch.FloatTensor]] = None
  205. encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
  206. encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
  207. decoder_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
  208. cross_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
  209. decoder_hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
  210. past_key_values: Optional[Cache] = None
  211. @dataclass
  212. class GenerateBeamDecoderOnlyOutput(ModelOutput):
  213. """
  214. Outputs of decoder-only generation models, when using beam methods.
  215. Args:
  216. sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
  217. The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
  218. if all batches finished early due to the `eos_token_id`.
  219. sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
  220. Final beam scores of the generated `sequences`.
  221. scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
  222. Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
  223. of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
  224. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
  225. with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
  226. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
  227. Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  228. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  229. each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
  230. beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
  231. Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
  232. `(batch_size*num_return_sequences, sequence_length)`.
  233. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
  234. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  235. `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
  236. hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
  237. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  238. `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
  239. past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
  240. Returns the model cache, used to speed up decoding. Different models have a different cache format, check
  241. the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
  242. """
  243. sequences: torch.LongTensor
  244. sequences_scores: Optional[torch.FloatTensor] = None
  245. scores: Optional[tuple[torch.FloatTensor]] = None
  246. logits: Optional[tuple[torch.FloatTensor]] = None
  247. beam_indices: Optional[torch.LongTensor] = None
  248. attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
  249. hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
  250. past_key_values: Optional[Cache] = None
  251. @dataclass
  252. class GenerateBeamEncoderDecoderOutput(ModelOutput):
  253. """
  254. Outputs of encoder-decoder generation models, when using beam methods.
  255. Args:
  256. sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
  257. The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
  258. if all batches finished early due to the `eos_token_id`.
  259. sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
  260. Final beam scores of the generated `sequences`.
  261. scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
  262. Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
  263. of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
  264. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
  265. with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
  266. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
  267. Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  268. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  269. each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
  270. beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
  271. Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
  272. `(batch_size*num_return_sequences, sequence_length)`.
  273. encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
  274. Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
  275. sequence_length, sequence_length)`.
  276. encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
  277. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  278. shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
  279. decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
  280. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  281. `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
  282. sequence_length)`.
  283. cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
  284. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  285. `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
  286. decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
  287. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  288. `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
  289. past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
  290. Returns the model cache, used to speed up decoding. Different models have a different cache format, check
  291. the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
  292. """
  293. sequences: torch.LongTensor
  294. sequences_scores: Optional[torch.FloatTensor] = None
  295. scores: Optional[tuple[torch.FloatTensor]] = None
  296. logits: Optional[tuple[torch.FloatTensor]] = None
  297. beam_indices: Optional[torch.LongTensor] = None
  298. encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
  299. encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
  300. decoder_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
  301. cross_attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
  302. decoder_hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
  303. past_key_values: Optional[Cache] = None
  # TODO (joao): remove the equivalent classes and typing shortcuts below in v5
  # Equivalent classes (kept for retrocompatibility purposes). Each legacy per-strategy name is a plain alias of one
  # of the four unified `Generate*Output` classes, so old and new names refer to the very same class object.
  GreedySearchDecoderOnlyOutput = ContrastiveSearchDecoderOnlyOutput = SampleDecoderOnlyOutput = (
      GenerateDecoderOnlyOutput
  )
  ContrastiveSearchEncoderDecoderOutput = GreedySearchEncoderDecoderOutput = SampleEncoderDecoderOutput = (
      GenerateEncoderDecoderOutput
  )
  BeamSearchDecoderOnlyOutput = BeamSampleDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput
  BeamSearchEncoderDecoderOutput = BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput

  # Legacy per-strategy unions: each pairs the encoder-decoder variant with the decoder-only one
  GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
  SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
  BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
  BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
  ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]

  # Typing shortcuts
  GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]
  GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput]
  GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
  325. class GenerationMixin(ContinuousMixin):
  326. """
  327. A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes.
  328. Inheriting from this class causes the model to have special generation-related behavior, such as loading a
  329. `GenerationConfig` at initialization time or ensuring `generate`-related tests are run in `transformers` CI.
  330. A model class should inherit from `GenerationMixin` to enable calling methods like `generate`, or when it
  331. has defined a custom `generate` method that relies on `GenerationMixin`, directly or indirectly, which
  332. approximately shares the same interface to public methods like `generate`. Three examples:
  333. - `LlamaForCausalLM` should inherit from `GenerationMixin` to enable calling `generate` and other public
  334. methods in the mixin;
  335. - `BlipForQuestionAnswering` has a custom `generate` method that approximately shares the same interface as
  336. `GenerationMixin.generate` (it has a few extra arguments, and the same output). That function also calls
  337. `GenerationMixin.generate` indirectly, through an inner model. As such, `BlipForQuestionAnswering` should
  338. inherit from `GenerationMixin` to benefit from all generation-related automation in our codebase;
  339. - `BarkModel` has a custom `generate` method and one of its inner models calls `GenerationMixin.generate`.
  340. However, its `generate` does not share the same interface as `GenerationMixin.generate`. In this case,
  341. `BarkModel` should NOT inherit from `GenerationMixin`, as it breaks the `generate` interface.
  342. The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
  343. - *greedy decoding* if `num_beams=1` and `do_sample=False`
  344. - *multinomial sampling* if `num_beams=1` and `do_sample=True`
  345. - *beam-search decoding* if `num_beams>1` and `do_sample=False`
  346. - *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
  347. - *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
  348. To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
  349. """
  def load_custom_generate(
      self,
      pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
      trust_remote_code: Optional[bool] = None,
      **kwargs,
  ) -> Callable:
      """
      Loads and returns a custom generate function, given a model repo.

      Args:
          pretrained_model_name_or_path (`str` or `os.PathLike`):
              Can be either:
                  - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                  - A path to a *directory* containing model weights saved using
                    [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
          trust_remote_code (`bool`, *optional*):
              Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
              should only be set to `True` for repositories you trust and in which you have read the code, as it will
              execute code present on the Hub on your local machine.
          **kwargs:
              Additional keyword arguments for remote code loading. They are forwarded both to
              `get_cached_module_file` and to `check_python_requirements`.

      Raises:
          OSError: If `pretrained_model_name_or_path` does not contain a `custom_generate` subdirectory with a
              `generate.py` file. The underlying lookup error is chained as `__cause__`.

      Returns:
          A callable that can be used to generate text.
      """
      # Fetches the generate.py file from the model repo. If it doesn't exist, a file in `.no_exist` cache directory
      # is created (preventing future hub requests), and an OSError is raised.
      try:
          module = get_cached_module_file(
              pretrained_model_name_or_path, module_file="custom_generate/generate.py", **kwargs
          )
      except OSError as err:
          # Chain explicitly so the original lookup failure stays visible to the caller
          raise OSError(
              f"`{pretrained_model_name_or_path}` does not contain a `custom_generate` subdirectory with a "
              "`generate.py` file, can't load the custom generate function."
          ) from err

      # Handle opt-in `trust_remote_code` and related exceptions. A path that exists on the local filesystem is
      # treated as local code; anything else is treated as remote (Hub) code.
      is_local_code = os.path.exists(pretrained_model_name_or_path)
      error_message = (
          f"The repository `{pretrained_model_name_or_path}` contains custom generation code that will override "
          "the default `generate` method."
      )
      resolve_trust_remote_code(
          trust_remote_code,
          pretrained_model_name_or_path,
          has_local_code=is_local_code,
          has_remote_code=not is_local_code,
          error_message=error_message,
      )

      # Check the Python requirements listed in `custom_generate/requirements.txt`, then load the function
      check_python_requirements(
          pretrained_model_name_or_path, requirements_file="custom_generate/requirements.txt", **kwargs
      )
      custom_generate_function = get_class_in_module("generate", module)
      return custom_generate_function
  395. has_local_code=is_local_code,
  396. has_remote_code=not is_local_code,
  397. error_message=error_message,
  398. )
  399. # Load the custom generate function
  400. check_python_requirements(
  401. pretrained_model_name_or_path, requirements_file="custom_generate/requirements.txt", **kwargs
  402. )
  403. custom_generate_function = get_class_in_module("generate", module)
  404. return custom_generate_function
  405. def _cache_dependant_input_preparation(
  406. self,
  407. input_ids: torch.LongTensor,
  408. inputs_embeds: Optional[torch.FloatTensor],
  409. cache_position: Optional[torch.LongTensor],
  410. ) -> tuple[torch.FloatTensor, torch.LongTensor]:
  411. """
  412. Generic cache-dependent input preparation
  413. The code is put in a separate function to allow granular unit testing
  414. as it needs a different implementation to be exportable.
  415. If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
  416. - Exception 1: when passing input_embeds, input_ids may be missing entries
  417. - Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
  418. - Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
  419. - Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
  420. generate the first token for each sequence. Later use the generated Input ids for continuation.
  421. The current implementation does not rely on ``self`` and could be
  422. a class method. It is left as a standard method to be easily rewritten.
  423. """
  424. if is_torchdynamo_exporting():
  425. return self._cache_dependant_input_preparation_exporting(input_ids, inputs_embeds, cache_position)
  426. if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
  427. inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
  428. elif (
  429. inputs_embeds is not None # Exception 1
  430. or (cache_position[-1] >= input_ids.shape[1]) # Exception 3
  431. ):
  432. input_ids = input_ids[:, -cache_position.shape[0] :]
  433. elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
  434. input_ids = input_ids[:, cache_position]
  435. return inputs_embeds, input_ids
  436. def _cache_dependant_input_preparation_exporting(
  437. self,
  438. input_ids: torch.LongTensor,
  439. inputs_embeds: Optional[torch.FloatTensor],
  440. cache_position: Optional[torch.LongTensor],
  441. ) -> tuple[torch.FloatTensor, torch.LongTensor]:
  442. """
  443. This method implements method ``_cache_dependant_input_preparation``
  444. with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
  445. The code is put in a separate function to allow granular unit testing.
  446. """
  447. if inputs_embeds is None:
  448. input_ids = input_ids[:, cache_position]
  449. else:
  450. # This is the code we need to implemented with torch.cond.
  451. # if input_ids.shape[1] == 0:
  452. # inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
  453. # else:
  454. # if cache_position[-1] >= input_ids.shape[1]:
  455. # input_ids = input_ids[:, -cache_position.shape[0] :]
  456. # else:
  457. # if input_ids.shape[1] != cache_position.shape[0]:
  458. # input_ids = input_ids[:, cache_position]
  459. # We need to clone the outputs to avoid aliasing.
  460. def branch_1(inputs_embeds, cache_position):
  461. return inputs_embeds[:, -cache_position.shape[0] :].clone()
  462. def branch_2(input_ids, cache_position):
  463. return input_ids[:, -cache_position.shape[0] :].clone()
  464. def branch_3(input_ids, cache_position):
  465. return input_ids[:, cache_position].clone()
  466. inputs_embeds, input_ids = torch.cond(
  467. input_ids.shape[1] == 0,
  468. (
  469. lambda input_ids, inputs_embeds, cache_position: (
  470. branch_1(inputs_embeds, cache_position),
  471. input_ids.clone(),
  472. )
  473. ),
  474. (
  475. lambda input_ids, inputs_embeds, cache_position: (
  476. inputs_embeds,
  477. torch.cond(
  478. cache_position[-1] >= input_ids.shape[1],
  479. branch_2,
  480. lambda input_ids, cache_position: (
  481. torch.cond(
  482. input_ids.shape[1] != cache_position.shape[0],
  483. branch_3,
  484. (lambda input_ids, cache_position: input_ids.clone()),
  485. [input_ids, cache_position],
  486. )
  487. ),
  488. [input_ids, cache_position],
  489. ),
  490. )
  491. ),
  492. [input_ids, inputs_embeds, cache_position],
  493. )
  494. return inputs_embeds, input_ids
  495. def prepare_inputs_for_generation(
  496. self,
  497. input_ids: torch.LongTensor,
  498. past_key_values: Optional[Cache] = None,
  499. attention_mask: Optional[torch.LongTensor] = None,
  500. inputs_embeds: Optional[torch.FloatTensor] = None,
  501. cache_position: Optional[torch.LongTensor] = None,
  502. **kwargs,
  503. ):
  504. """
  505. Prepare the model inputs for generation. Notable steps include selecting the correct input key and cloning when appropriate,
  506. creating position_ids from the attention_mask when missing, slicing inputs and converting 2D attention masks to 4D for
  507. compilable caches, and finally forwarding all additional keyword arguments unchanged to the model's forward pass.
  508. See the forward pass in the model documentation for expected arguments (different models might have different
  509. requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
  510. """
  511. # 1. Handle BC:
  512. model_inputs = {}
  513. model_inputs["cache_position"] = cache_position
  514. # 2. Generic cache-dependent input preparation
  515. if past_key_values is not None:
  516. model_inputs["past_key_values"] = past_key_values
  517. # TODO (joao): handle the case where cache length == input_ids length. The function below results in an
  518. # exception because we get empty input_ids after slicing. In essence, we need to roll back the cache 1
  519. # token to recompute the logits for the first token to be generated (but not all caches support roll backs)
  520. inputs_embeds, input_ids = self._cache_dependant_input_preparation(
  521. input_ids, inputs_embeds, cache_position
  522. )
  523. # 3. Prepare base model inputs
  524. input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
  525. # if `inputs_embeds` are passed, we only want to use them in the 1st generation step for every prompt.
  526. if not self.config.is_encoder_decoder:
  527. if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
  528. model_inputs[input_ids_key] = None
  529. model_inputs["inputs_embeds"] = inputs_embeds
  530. else:
  531. # `clone` calls in this function ensure a consistent stride. See #32227
  532. model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
  533. model_inputs["inputs_embeds"] = None
  534. else:
  535. model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
  536. # 4. Create missing `position_ids` on the fly
  537. encoder_attention_mask = attention_mask if self.config.is_encoder_decoder else None
  538. attention_mask = (
  539. kwargs.pop("decoder_attention_mask", None) if self.config.is_encoder_decoder else attention_mask
  540. )
  541. attention_mask_key = "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask"
  542. position_ids_key = "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids"
  543. if (
  544. attention_mask is not None
  545. and kwargs.get(position_ids_key) is None
  546. and position_ids_key in set(inspect.signature(self.forward).parameters.keys())
  547. ):
  548. position_ids = attention_mask.long().cumsum(-1) - 1
  549. position_ids.masked_fill_(attention_mask == 0, 1)
  550. kwargs[position_ids_key] = position_ids # placed in kwargs for further processing (see below)
  551. # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
  552. for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
  553. model_input = kwargs.get(model_input_name)
  554. if model_input is not None:
  555. if past_key_values is not None:
  556. current_input_length = (
  557. model_inputs["inputs_embeds"].shape[1]
  558. if model_inputs.get("inputs_embeds") is not None
  559. else model_inputs[input_ids_key].shape[1]
  560. )
  561. model_input = model_input[:, -current_input_length:]
  562. model_input = model_input.clone(memory_format=torch.contiguous_format)
  563. model_inputs[model_input_name] = model_input
  564. # 6. Create 4D attention mask is we are using a compilable cache (important for performant compiled forward
  565. # pass)
  566. if (
  567. isinstance(past_key_values, Cache)
  568. and past_key_values.is_compileable
  569. and attention_mask is not None
  570. and attention_mask.ndim == 2
  571. ):
  572. if not self.config.is_encoder_decoder and model_inputs["inputs_embeds"] is not None:
  573. batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
  574. else:
  575. batch_size, sequence_length = model_inputs[input_ids_key].shape[:2]
  576. # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
  577. # the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder.
  578. base_model = getattr(self, self.base_model_prefix, self)
  579. decoder = base_model.get_decoder() if hasattr(base_model, "get_decoder") else None
  580. causal_mask_creation_function = getattr(
  581. base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
  582. )
  583. if causal_mask_creation_function is None and decoder is not None: # it may be in the decoder
  584. causal_mask_creation_function = getattr(
  585. decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None
  586. )
  587. # If it's not defined, it means the model uses the new general mask API
  588. if causal_mask_creation_function is None: # can't be found
  589. token_type_ids = model_inputs.get("token_type_ids")
  590. position_ids = model_inputs.get(position_ids_key)
  591. # Some models may overwrite the general one
  592. causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
  593. attention_mask = causal_mask_creation_function(
  594. config=self.config,
  595. # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
  596. input_embeds=torch.empty((batch_size, sequence_length), dtype=self.dtype),
  597. attention_mask=attention_mask,
  598. cache_position=cache_position,
  599. past_key_values=past_key_values,
  600. position_ids=position_ids,
  601. token_type_ids=token_type_ids,
  602. )
  603. else:
  604. attention_mask = causal_mask_creation_function(
  605. attention_mask,
  606. sequence_length=sequence_length,
  607. target_length=past_key_values.get_max_cache_shape(),
  608. dtype=self.dtype,
  609. cache_position=cache_position,
  610. batch_size=batch_size,
  611. config=self.config,
  612. past_key_values=past_key_values,
  613. )
  614. if attention_mask is not None:
  615. model_inputs[attention_mask_key] = attention_mask
  616. if encoder_attention_mask is not None:
  617. model_inputs["attention_mask"] = encoder_attention_mask
  618. # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
  619. for key, value in kwargs.items():
  620. if key not in model_inputs:
  621. model_inputs[key] = value
  622. # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
  623. model_inputs.pop("labels", None)
  624. return model_inputs
  625. def _prepare_model_inputs(
  626. self,
  627. inputs: Optional[torch.Tensor] = None,
  628. bos_token_id: Optional[torch.Tensor] = None,
  629. model_kwargs: Optional[dict[str, torch.Tensor]] = None,
  630. ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
  631. """
  632. This function extracts the model-specific `inputs` for generation.
  633. """
  634. # 1. retrieve all kwargs that are non-None or non-model input related.
  635. # some encoder-decoder models have different names for model and encoder
  636. if (
  637. self.config.is_encoder_decoder
  638. and hasattr(self, "encoder")
  639. and self.encoder.main_input_name != self.main_input_name
  640. ):
  641. input_name = self.encoder.main_input_name
  642. else:
  643. input_name = self.main_input_name
  644. model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
  645. # 2. check whether model_input_name is passed as kwarg
  646. # if yes and `inputs` is None use kwarg inputs
  647. inputs_kwarg = model_kwargs.pop(input_name, None)
  648. if inputs_kwarg is not None and inputs is not None:
  649. raise ValueError(
  650. f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. "
  651. f"Make sure to either pass {inputs} or {input_name}=..."
  652. )
  653. elif inputs_kwarg is not None:
  654. inputs = inputs_kwarg
  655. # 3. In the presence of `inputs_embeds` for text models:
  656. # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
  657. # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
  658. # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
  659. # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
  660. # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
  661. if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
  662. if model_kwargs["inputs_embeds"] is None:
  663. model_kwargs.pop("inputs_embeds")
  664. elif not self.config.is_encoder_decoder:
  665. has_inputs_embeds_forwarding = "inputs_embeds" in set(
  666. inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
  667. )
  668. if not has_inputs_embeds_forwarding:
  669. raise ValueError(
  670. f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
  671. "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
  672. "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
  673. )
  674. # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
  675. # the attention mask) can rely on the actual model input.
  676. model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
  677. inputs, bos_token_id, model_kwargs=model_kwargs
  678. )
  679. inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
  680. else:
  681. if inputs is not None:
  682. raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
  683. inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
  684. # 4. if `inputs` is still None, try to create `input_ids` from BOS token
  685. inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
  686. return inputs, input_name, model_kwargs
  687. def _maybe_initialize_input_ids_for_generation(
  688. self,
  689. inputs: Optional[torch.Tensor] = None,
  690. bos_token_id: Optional[torch.Tensor] = None,
  691. model_kwargs: Optional[dict[str, torch.Tensor]] = None,
  692. ) -> torch.LongTensor:
  693. """Initializes input ids for generation, if necessary."""
  694. if inputs is not None:
  695. return inputs
  696. encoder_outputs = model_kwargs.get("encoder_outputs")
  697. if self.config.is_encoder_decoder and encoder_outputs is not None:
  698. # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
  699. shape = encoder_outputs.last_hidden_state.size()[:-1]
  700. return torch.ones(shape, dtype=torch.long, device=self.device) * -100
  701. # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
  702. # soft-prompting or in multimodal implementations built on top of decoder-only language models.
  703. batch_size = 1
  704. for value in model_kwargs.values():
  705. if isinstance(value, torch.Tensor):
  706. batch_size = value.shape[0]
  707. break
  708. if "inputs_embeds" in model_kwargs:
  709. return torch.ones((batch_size, 0), dtype=torch.long, device=self.device)
  710. if bos_token_id is None:
  711. raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
  712. return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
  713. def _prepare_attention_mask_for_generation(
  714. self,
  715. inputs_tensor: torch.Tensor,
  716. generation_config: GenerationConfig,
  717. model_kwargs: dict[str, Any],
  718. ) -> torch.LongTensor:
  719. pad_token_id = generation_config._pad_token_tensor
  720. eos_token_id = generation_config._eos_token_tensor
  721. # `input_ids` may be present in the model kwargs, instead of being the main input (e.g. multimodal model)
  722. if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
  723. inputs_tensor = model_kwargs["input_ids"]
  724. # No information for attention mask inference -> return default attention mask
  725. default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
  726. if pad_token_id is None:
  727. return default_attention_mask
  728. is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
  729. if not is_input_ids:
  730. return default_attention_mask
  731. is_pad_token_in_inputs = (pad_token_id is not None) and (
  732. isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
  733. )
  734. is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
  735. isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
  736. )
  737. can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
  738. attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()
  739. attention_mask = (
  740. attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
  741. )
  742. return attention_mask
  743. def _prepare_encoder_decoder_kwargs_for_generation(
  744. self,
  745. inputs_tensor: torch.Tensor,
  746. model_kwargs,
  747. model_input_name: Optional[str],
  748. generation_config: GenerationConfig,
  749. ) -> dict[str, Any]:
  750. # 1. get encoder
  751. encoder = self.get_encoder()
  752. # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
  753. # as the inputs.
  754. if hasattr(self, "hf_device_map"):
  755. if hasattr(encoder, "_hf_hook"):
  756. encoder._hf_hook.io_same_device = True
  757. else:
  758. add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))
  759. # 2. Prepare encoder args and encoder kwargs from model kwargs and generation config.
  760. irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
  761. encoder_kwargs = {
  762. argument: value
  763. for argument, value in model_kwargs.items()
  764. if not any(argument.startswith(p) for p in irrelevant_prefix)
  765. }
  766. encoder_signature = set(inspect.signature(encoder.forward).parameters)
  767. encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
  768. if not encoder_accepts_wildcard:
  769. encoder_kwargs = {
  770. argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
  771. }
  772. encoder_kwargs["output_attentions"] = generation_config.output_attentions
  773. encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
  774. # 3. make sure that encoder returns `ModelOutput`
  775. model_input_name = model_input_name if model_input_name is not None else self.main_input_name
  776. encoder_kwargs["return_dict"] = True
  777. encoder_kwargs[model_input_name] = inputs_tensor
  778. model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) # type: ignore
  779. return model_kwargs
  780. def _prepare_decoder_input_ids_for_generation(
  781. self,
  782. batch_size: int,
  783. model_input_name: str,
  784. model_kwargs: dict[str, torch.Tensor],
  785. decoder_start_token_id: torch.Tensor,
  786. device: Optional[torch.device] = None,
  787. ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
  788. """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
  789. # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
  790. # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
  791. if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
  792. decoder_input_ids = model_kwargs.pop("decoder_input_ids")
  793. elif "input_ids" in model_kwargs and model_input_name != "input_ids":
  794. decoder_input_ids = model_kwargs.pop("input_ids")
  795. else:
  796. decoder_input_ids = None
  797. # 2. `decoder_start_token_id` must have shape (batch_size, 1)
  798. if device is None:
  799. device = self.device
  800. if decoder_start_token_id.ndim == 1:
  801. if decoder_start_token_id.shape[0] != batch_size:
  802. raise ValueError(
  803. f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}"
  804. )
  805. decoder_start_token_id = decoder_start_token_id.view(-1, 1)
  806. else:
  807. decoder_start_token_id = (
  808. torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
  809. )
  810. # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
  811. # no user input -> use decoder_start_token_id as decoder_input_ids
  812. if decoder_input_ids is None:
  813. decoder_input_ids = decoder_start_token_id
  814. # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the
  815. # original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic.
  816. # See: https://github.com/huggingface/transformers/pull/31470
  817. elif "donut" in self.__class__.__name__.lower() or (
  818. self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower()
  819. ):
  820. pass
  821. elif self.config.model_type == "whisper":
  822. pass
  823. # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
  824. # decoder_attention_mask if provided)
  825. elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item():
  826. decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1)
  827. if "decoder_attention_mask" in model_kwargs:
  828. decoder_attention_mask = model_kwargs["decoder_attention_mask"]
  829. decoder_attention_mask = torch.cat(
  830. (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
  831. dim=-1,
  832. )
  833. model_kwargs["decoder_attention_mask"] = decoder_attention_mask
  834. return decoder_input_ids, model_kwargs
  835. @staticmethod
  836. def _expand_inputs_for_generation(
  837. expand_size: int = 1,
  838. is_encoder_decoder: bool = False,
  839. input_ids: Optional[torch.LongTensor] = None,
  840. **model_kwargs,
  841. ) -> tuple[torch.LongTensor, dict[str, Any]]:
  842. """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
  843. # Do not call torch.repeat_interleave if expand_size is 1 because it clones
  844. # the input tensor and thus requires more memory although no change is applied
  845. if expand_size == 1:
  846. return input_ids, model_kwargs
  847. def _expand_dict_for_generation(dict_to_expand):
  848. for key in dict_to_expand:
  849. if (
  850. key != "cache_position"
  851. and dict_to_expand[key] is not None
  852. and isinstance(dict_to_expand[key], torch.Tensor)
  853. ):
  854. dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
  855. return dict_to_expand
  856. if input_ids is not None:
  857. input_ids = input_ids.repeat_interleave(expand_size, dim=0)
  858. model_kwargs = _expand_dict_for_generation(model_kwargs)
  859. if is_encoder_decoder:
  860. if model_kwargs.get("encoder_outputs") is None:
  861. raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
  862. model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
  863. return input_ids, model_kwargs
  def _update_model_kwargs_for_generation(
      self,
      outputs: ModelOutput,
      model_kwargs: dict[str, Any],
      is_encoder_decoder: bool = False,
      num_new_tokens: int = 1,
  ) -> dict[str, Any]:
      """
      Carry the state produced by one forward pass over to the next generation step.

      `model_kwargs` is updated in place and returned:
      - the first name from `ALL_CACHE_NAMES` present in `outputs` is copied into `model_kwargs` (the legacy
        `past_buckets_states` / `mems` outputs are stored under `past_key_values`);
      - `token_type_ids`, if present, gains one column repeating its last value;
      - `attention_mask` (decoder-only) or `decoder_attention_mask` (encoder-decoder), if present, gains one
        column of ones;
      - `cache_position` advances by `num_new_tokens`: when `use_cache` is missing or truthy only the shifted last
        position is kept, otherwise the new positions are appended to the existing ones.
      """
      # Store the cache under the name the model code expects as input.
      for output_cache_name in ALL_CACHE_NAMES:
          if output_cache_name not in outputs:
              continue
          # TODO (joao): remove output/input mismatch when these old models (xlnet, reformer) are deprecated
          is_legacy_name = output_cache_name in ("past_buckets_states", "mems")
          input_cache_name = "past_key_values" if is_legacy_name else output_cache_name
          model_kwargs[input_cache_name] = getattr(outputs, output_cache_name)
          break

      # Repeat the last token type for the newly generated position.
      if "token_type_ids" in model_kwargs:
          type_ids = model_kwargs["token_type_ids"]
          model_kwargs["token_type_ids"] = torch.cat([type_ids, type_ids[:, -1].unsqueeze(-1)], dim=-1)

      # The mask that grows with generated tokens lives on the decoder side for encoder-decoder models.
      mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
      if mask_key in model_kwargs:
          mask = model_kwargs[mask_key]
          model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)

      if not model_kwargs.get("use_cache", True):
          # Without a cache the full position history is kept and extended.
          previous = model_kwargs.pop("cache_position")
          last = previous[-1]
          appended = torch.arange(last + 1, last + num_new_tokens + 1, dtype=previous.dtype).to(previous.device)
          model_kwargs["cache_position"] = torch.cat((previous, appended))
      else:
          # With a cache only the last position, shifted by `num_new_tokens`, is kept.
          model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
      return model_kwargs
  def _get_candidate_generator(
      self,
      generation_config: GenerationConfig,
      input_ids: torch.LongTensor,
      inputs_tensor: torch.Tensor,
      logits_processor: LogitsProcessorList,
      model_kwargs: dict[str, Any],
      assistant_model: Optional["PreTrainedModel"] = None,
      target_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
      assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
  ) -> CandidateGenerator:
      """
      Build the candidate generator used by `assisted_generation`.

      Selection order:
      1. `assistant_early_exit` set -> `EarlyExitCandidateGenerator`, with this model acting as its own assistant;
      2. `prompt_lookup_num_tokens` set -> `PromptLookupCandidateGenerator`;
      3. `assistant_model`, `target_tokenizer` and `assistant_tokenizer` all given -> `do_sample is True` selects
         `UniversalSpeculativeDecodingGenerator`, `do_sample is False` selects
         `AssistedCandidateGeneratorDifferentTokenizers`, and any other value raises `ValueError`;
      4. otherwise -> `AssistedCandidateGenerator`.

      Note: the `UniversalSpeculativeDecodingGenerator` branch sets `repetition_penalty` to None on
      `assistant_model.generation_config` itself.
      """
      if generation_config.assistant_early_exit is not None:
          return EarlyExitCandidateGenerator(
              input_ids=input_ids,
              assistant_model=self,
              generation_config=generation_config,
              model_kwargs=model_kwargs,
              inputs_tensor=inputs_tensor,
              logits_processor=logits_processor,
          )

      if generation_config.prompt_lookup_num_tokens is not None:
          return PromptLookupCandidateGenerator(
              eos_token_id=generation_config._eos_token_tensor,
              num_output_tokens=generation_config.prompt_lookup_num_tokens,
              max_matching_ngram_size=generation_config.max_matching_ngram_size or 2,
              max_length=generation_config.max_length,
              logits_processor=logits_processor,
              vocab_size=self.config.get_text_config().vocab_size,
          )

      tokenizer_pair_given = all(
          component is not None for component in (assistant_model, target_tokenizer, assistant_tokenizer)
      )
      if not tokenizer_pair_given:
          return AssistedCandidateGenerator(
              input_ids=input_ids,
              assistant_model=assistant_model,
              generation_config=generation_config,
              model_kwargs=model_kwargs,
              inputs_tensor=inputs_tensor,
              logits_processor=logits_processor,
          )

      sampling = generation_config.do_sample
      if sampling is True:
          atm_translator = AssistantVocabTranslatorCache.get_translator(
              target_tokenizer,
              assistant_tokenizer,
              self.config.get_text_config().vocab_size,
              assistant_model=assistant_model,
              assistant_prune_lm_head=True,  # prune LM head of assistant model
          )
          # The assistant's LM head is pruned, so its token ids no longer line up with logit indices and a
          # repetition penalty on the assistant would penalize the wrong entries.
          assistant_model.generation_config.repetition_penalty = None
          return UniversalSpeculativeDecodingGenerator(
              input_ids=input_ids,
              assistant_model=assistant_model,
              generation_config=generation_config,
              model_kwargs=model_kwargs,
              inputs_tensor=inputs_tensor,
              logits_processor=logits_processor,
              target_tokenizer=target_tokenizer,
              assistant_tokenizer=assistant_tokenizer,
              atm_translator=atm_translator,
          )

      if sampling is False:
          return AssistedCandidateGeneratorDifferentTokenizers(
              input_ids=input_ids,
              assistant_model=assistant_model,
              generation_config=generation_config,
              model_kwargs=model_kwargs,
              inputs_tensor=inputs_tensor,
              logits_processor=logits_processor,
              target_tokenizer=target_tokenizer,
              assistant_tokenizer=assistant_tokenizer,
          )

      raise ValueError(
          f"Invalid value for `do_sample`: expected a boolean, got {type(generation_config.do_sample).__name__}"
      )
  def _get_logits_processor(
      self,
      generation_config: GenerationConfig,
      input_ids_seq_length: Optional[int] = None,
      encoder_input_ids: Optional[torch.LongTensor] = None,
      prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
      logits_processor: Optional[LogitsProcessorList] = None,
      device: Optional[str] = None,
      model_kwargs: Optional[dict[str, Any]] = None,
      negative_prompt_ids: Optional[torch.Tensor] = None,
      negative_prompt_attention_mask: Optional[torch.Tensor] = None,
  ) -> LogitsProcessorList:
      """
      Build the [`LogitsProcessorList`] whose [`LogitsProcessor`] instances modify the language model head scores.

      Processors are created from `generation_config` and then merged with the user-provided `logits_processor`
      (a user processor of the same type replaces the default one). If `do_sample` is set, the sampling warpers
      (temperature, top-k, top-p, min-p, typical, epsilon, eta) are appended after the merge. Watermarking comes
      next, and `LogitNormalization` is always last when `renormalize_logits` is True.

      Args:
          generation_config: configuration whose flags select the processors.
          input_ids_seq_length: prompt length, used by the min-new-tokens, exponential-decay and begin-suppress
              processors.
          encoder_input_ids: input ids for the encoder-based processors (`encoder_repetition_penalty`,
              `encoder_no_repeat_ngram_size`). When None or not 2-D, those processors are skipped with a
              `UserWarning`.
          prefix_allowed_tokens_fn: optional constraint function for `PrefixConstrainedLogitsProcessor`.
          logits_processor: optional user-provided processors to merge in.
          device: forwarded to the processors that accept a device.
          model_kwargs: not read by this method.
          negative_prompt_ids: unconditional ids for classifier-free guidance.
          negative_prompt_attention_mask: attention mask for `negative_prompt_ids`.
      """
      # instantiate processors list
      processors = LogitsProcessorList()
      if logits_processor is None:
          logits_processor = []

      if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:
          processors.append(
              UnbatchedClassifierFreeGuidanceLogitsProcessor(
                  generation_config.guidance_scale,
                  self,
                  unconditional_ids=negative_prompt_ids,
                  unconditional_attention_mask=negative_prompt_attention_mask,
                  use_cache=generation_config.use_cache,
              )
          )
      if generation_config.sequence_bias is not None:
          processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))

      if (
          generation_config.encoder_repetition_penalty is not None
          and generation_config.encoder_repetition_penalty != 1.0
      ):
          # `encoder_input_ids` defaults to None: check it explicitly so the warning below is emitted instead of
          # an AttributeError on `.shape`.
          if encoder_input_ids is not None and len(encoder_input_ids.shape) == 2:
              processors.append(
                  EncoderRepetitionPenaltyLogitsProcessor(
                      penalty=generation_config.encoder_repetition_penalty,
                      encoder_input_ids=encoder_input_ids,
                  )
              )
          else:
              warnings.warn(
                  "Passing `encoder_repetition_penalty` requires some form of `input_ids` to be passed to "
                  "`generate`, ignoring the argument.",
                  UserWarning,
              )
      if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
          processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
      if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
          processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
      if (
          generation_config.encoder_no_repeat_ngram_size is not None
          and generation_config.encoder_no_repeat_ngram_size > 0
      ):
          # Same None guard as for `encoder_repetition_penalty` above.
          if encoder_input_ids is not None and len(encoder_input_ids.shape) == 2:
              processors.append(
                  EncoderNoRepeatNGramLogitsProcessor(
                      generation_config.encoder_no_repeat_ngram_size,
                      encoder_input_ids,
                  )
              )
          else:
              warnings.warn(
                  "Passing `encoder_no_repeat_ngram_size` requires some form of `input_ids` to be passed to "
                  "`generate`, ignoring the argument.",
                  UserWarning,
              )
      if generation_config.bad_words_ids is not None:
          processors.append(
              NoBadWordsLogitsProcessor(
                  generation_config.bad_words_ids,
                  generation_config._eos_token_tensor,
              )
          )
      if (
          generation_config.min_length is not None
          and getattr(generation_config, "_eos_token_tensor", None) is not None
          and generation_config.min_length > 0
      ):
          processors.append(
              MinLengthLogitsProcessor(
                  generation_config.min_length,
                  generation_config._eos_token_tensor,
                  device=device,
              )
          )
      if (
          generation_config.min_new_tokens is not None
          and getattr(generation_config, "_eos_token_tensor", None) is not None
          and generation_config.min_new_tokens > 0
      ):
          processors.append(
              MinNewTokensLengthLogitsProcessor(
                  input_ids_seq_length,
                  generation_config.min_new_tokens,
                  generation_config._eos_token_tensor,
                  device=device,
              )
          )
      if prefix_allowed_tokens_fn is not None:
          processors.append(
              PrefixConstrainedLogitsProcessor(
                  prefix_allowed_tokens_fn,
                  generation_config.num_beams,
              )
          )
      if generation_config.forced_bos_token_id is not None:
          processors.append(
              ForcedBOSTokenLogitsProcessor(
                  generation_config.forced_bos_token_id,
              )
          )
      if generation_config.forced_eos_token_id is not None:
          processors.append(
              ForcedEOSTokenLogitsProcessor(
                  generation_config.max_length,
                  generation_config.forced_eos_token_id,
                  device=device,
              )
          )
      if generation_config.remove_invalid_values is True:
          processors.append(InfNanRemoveLogitsProcessor())
      if generation_config.exponential_decay_length_penalty is not None:
          processors.append(
              ExponentialDecayLengthPenalty(
                  generation_config.exponential_decay_length_penalty,
                  generation_config._eos_token_tensor,
                  input_ids_seq_length,
              )
          )
      if generation_config.suppress_tokens is not None:
          processors.append(
              SuppressTokensLogitsProcessor(
                  generation_config.suppress_tokens,
                  device=device,
              )
          )
      if generation_config.begin_suppress_tokens is not None:
          # With a single-token prompt and a forced BOS token, the first free position is one step later.
          begin_index = input_ids_seq_length
          begin_index = (
              begin_index
              if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
              else begin_index + 1
          )
          processors.append(
              SuppressTokensAtBeginLogitsProcessor(
                  generation_config.begin_suppress_tokens,
                  begin_index,
                  device=device,
              )
          )

      # TODO (joao): find a strategy to specify the order of the processors
      processors = self._merge_criteria_processor_list(processors, logits_processor)

      # Processors previously known as `LogitsWarpers`, only applied with sampling strategies
      if generation_config.do_sample:
          # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
          # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
          if generation_config.num_beams > 1:
              if isinstance(generation_config._eos_token_tensor, list):
                  min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
              elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
                  min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
              else:
                  min_tokens_to_keep = 2
          else:
              min_tokens_to_keep = 1

          # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
          # all samplers can be found in `generation_utils_samplers.py`
          if generation_config.temperature is not None and generation_config.temperature != 1.0:
              processors.append(TemperatureLogitsWarper(generation_config.temperature))
          if generation_config.top_k is not None and generation_config.top_k != 0:
              processors.append(
                  TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)
              )
          if generation_config.top_p is not None and generation_config.top_p < 1.0:
              processors.append(
                  TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)
              )
          if generation_config.min_p is not None:
              # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
              processors.append(
                  MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)
              )
          if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
              processors.append(
                  TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
              )
          if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
              processors.append(
                  EpsilonLogitsWarper(
                      epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep
                  )
              )
          if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
              processors.append(
                  EtaLogitsWarper(
                      epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
                  )
              )

      # Watermarking should be after all logits processing is finished (see #34630)
      if generation_config.watermarking_config is not None:
          processors.append(
              generation_config.watermarking_config.construct_processor(
                  self.config.get_text_config().vocab_size, device
              )
          )

      # `LogitNormalization` should always be the last logit processor, when present
      if generation_config.renormalize_logits is True:
          processors.append(LogitNormalization())
      return processors
  def _get_stopping_criteria(
      self,
      generation_config: GenerationConfig,
      stopping_criteria: Optional[StoppingCriteriaList],
      tokenizer: Optional["PreTrainedTokenizerBase"] = None,
  ) -> StoppingCriteriaList:
      """
      Build the [`StoppingCriteriaList`] used by `generate`.

      Criteria are created from `generation_config`: max length, max time, stop strings, EOS tokens and, for
      assistant models, a confidence threshold. They are then merged with the user-provided `stopping_criteria`,
      where a user criterion of the same type replaces the default one.

      Args:
          generation_config: configuration whose fields select the criteria.
          stopping_criteria: optional user-provided criteria; `None` is treated as an empty list.
          tokenizer: required when `generation_config.stop_strings` is set.

      Raises:
          ValueError: if stop strings are configured but `tokenizer` is None.
      """
      # Normalize a missing custom list before merging, as `_get_logits_processor` does for `logits_processor`;
      # otherwise the merge would fail on `len(None)`.
      if stopping_criteria is None:
          stopping_criteria = StoppingCriteriaList()

      criteria = StoppingCriteriaList()
      if generation_config.max_length is not None:
          max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
          criteria.append(
              MaxLengthCriteria(
                  max_length=generation_config.max_length,
                  max_position_embeddings=max_position_embeddings,
              )
          )
      if generation_config.max_time is not None:
          criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
      if generation_config.stop_strings is not None:
          if tokenizer is None:
              raise ValueError(
                  "There are one or more stop strings, either in the arguments to `generate` or in the "
                  "model's generation config, but we could not locate a tokenizer. When generating with "
                  "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`."
              )
          criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
      if generation_config._eos_token_tensor is not None:
          criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
      if (
          generation_config.is_assistant
          and generation_config.assistant_confidence_threshold is not None
          and generation_config.assistant_confidence_threshold > 0
      ):
          criteria.append(
              ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold)
          )
      criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
      return criteria
  def _merge_criteria_processor_list(
      self,
      default_list: Union[LogitsProcessorList, StoppingCriteriaList],
      custom_list: Optional[Union[LogitsProcessorList, StoppingCriteriaList]],
  ) -> Union[LogitsProcessorList, StoppingCriteriaList]:
      """
      Merge user-defined processors/criteria with the ones instantiated inside `generate`. In case the same
      processor/criteria is present on both lists, use the user-defined one.
      (Note: up to v4.49.0, this function threw an exception if the same logit processor was found twice.)

      Args:
          default_list: processors/criteria created by `generate` from the generation config.
          custom_list: user-provided processors/criteria. `None` or an empty list means there is nothing to merge,
              and `default_list` is returned as is.

      Returns:
          A new list of the same type as `default_list`. Each default object is replaced, in its position, by the
          first custom object of exactly the same type, with a one-time warning. Custom objects not already in the
          result are appended at the end in their original order.
      """
      # `None` is handled like an empty list so callers forwarding an optional argument do not crash on `len(None)`.
      if not custom_list:
          return default_list

      final_list = type(default_list)()
      for default in default_list:
          for custom in custom_list:
              # Exact type match only: a subclass of a default processor/criteria is not treated as a duplicate.
              if type(custom) is type(default):
                  object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
                  logger.warning_once(
                      f"A custom {object_type} of type {type(custom)} has been passed to `.generate()`, but it "
                      f"was also created in `.generate()`, given its parameterization. The custom {type(custom)} "
                      f"will take precedence. Please check the docstring of {type(custom)} to see related "
                      "`.generate()` flags."
                  )
                  final_list.append(custom)
                  break
          else:
              # No custom override for this default object: keep it.
              final_list.append(default)

      # Custom objects that did not override a default go at the end, in their original order.
      for custom in custom_list:
          if custom not in final_list:
              final_list.append(custom)
      return final_list
  def compute_transition_scores(
      self,
      sequences: torch.Tensor,
      scores: tuple[torch.Tensor],
      beam_indices: Optional[torch.Tensor] = None,
      normalize_logits: bool = False,
  ) -> torch.Tensor:
      """
      Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was
      used). This is a convenient method to quickly obtain the scores of the selected tokens at generation time.

      Parameters:
          sequences (`torch.LongTensor`):
              The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or
              shorter if all batches finished early due to the `eos_token_id`.
          scores (`tuple(torch.FloatTensor)`):
              Transition scores for each vocabulary token at each generation step. Beam transition scores consisting
              of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
              Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
              with each tensor of shape `(batch_size*num_beams, vocab_size)`. The vocabulary size is read from these
              tensors, so it does not have to equal `config.vocab_size`.
          beam_indices (`torch.LongTensor`, *optional*):
              Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
              `(batch_size*num_return_sequences, sequence_length)`. Only required if a `num_beams>1` at
              generate-time. Negative entries mark steps after a beam finished; their scores are set to 0.
          normalize_logits (`bool`, *optional*, defaults to `False`):
              Whether to normalize the logits (which, for legacy reasons, may be unnormalized).

      Return:
          `torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)` containing
              the transition scores (logits)

      Examples:

      ```python
      >>> from transformers import GPT2Tokenizer, AutoModelForCausalLM
      >>> import numpy as np

      >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
      >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
      >>> tokenizer.pad_token_id = tokenizer.eos_token_id
      >>> inputs = tokenizer(["Today is"], return_tensors="pt")

      >>> # Example 1: Print the scores for each token generated with Greedy Search
      >>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
      >>> transition_scores = model.compute_transition_scores(
      ...     outputs.sequences, outputs.scores, normalize_logits=True
      ... )
      >>> # input_length is the length of the input prompt for decoder-only models, like the GPT family, and 1 for
      >>> # encoder-decoder models, like BART or T5.
      >>> input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
      >>> generated_tokens = outputs.sequences[:, input_length:]
      >>> for tok, score in zip(generated_tokens[0], transition_scores[0]):
      ...     # | token | token string | log probability | probability
      ...     print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
      |   262 |  the     | -1.414 | 24.33%
      |  1110 |  day     | -2.609 | 7.36%
      |   618 |  when    | -2.010 | 13.40%
      |   356 |  we      | -1.859 | 15.58%
      |   460 |  can     | -2.508 | 8.14%

      >>> # Example 2: Reconstruct the sequence scores from Beam Search
      >>> outputs = model.generate(
      ...     **inputs,
      ...     max_new_tokens=5,
      ...     num_beams=4,
      ...     num_return_sequences=4,
      ...     return_dict_in_generate=True,
      ...     output_scores=True,
      ... )
      >>> transition_scores = model.compute_transition_scores(
      ...     outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
      ... )
      >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores.
      >>> # Tip 1: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the
      >>> # use case, you might want to recompute it with `normalize_logits=True`.
      >>> # Tip 2: the output length does NOT include the input length
      >>> output_length = np.sum(transition_scores.numpy() < 0, axis=1)
      >>> length_penalty = model.generation_config.length_penalty
      >>> reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)
      >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores))
      True
      ```"""
      # 1. In absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent
      # to a beam search approach where the first (and only) beam is always selected
      if beam_indices is None:
          beam_indices = torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device)
          beam_indices = beam_indices.expand(-1, len(scores))

      # Width of each step's score tensor. It is read from `scores` rather than `self.config` so that the flat
      # `beam * vocab_size + token` indexing below always matches the actual layout of `scores`, even when the
      # logits width differs from the configured vocabulary size.
      vocab_size = scores[0].shape[-1]

      # 2. reshape scores as [batch_size*vocab_size, # generation steps] with # generation steps being
      # seq_len - input_length
      scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)

      # 3. Optionally normalize the logits (across the vocab dimension)
      if normalize_logits:
          scores = scores.reshape(-1, vocab_size, scores.shape[-1])
          scores = torch.nn.functional.log_softmax(scores, dim=1)
          scores = scores.reshape(-1, scores.shape[-1])

      # 4. cut beam_indices to longest beam length
      beam_indices_mask = beam_indices < 0
      max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
      beam_indices = beam_indices.clone()[:, :max_beam_length]
      beam_indices_mask = beam_indices_mask[:, :max_beam_length]

      # 5. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards
      beam_indices[beam_indices_mask] = 0

      # 6. multiply beam_indices with vocab size to gather correctly from scores
      beam_sequence_indices = beam_indices * vocab_size

      # 7. Define which indices contributed to scores
      cut_idx = sequences.shape[-1] - max_beam_length
      indices = sequences[:, cut_idx:] + beam_sequence_indices

      # 8. Compute scores
      transition_scores = scores.gather(0, indices)

      # 9. Mask out transition_scores of beams that stopped early
      transition_scores[beam_indices_mask] = 0

      return transition_scores
  def _validate_generation_mode(self, generation_mode, generation_config, generation_mode_kwargs):
      """
      Validate the generation-mode-specific arguments passed to `generate`.

      An entry of `generation_mode_kwargs` counts as provided only when its value is not None. This rule applies
      uniformly to `streamer`, `assistant_model`, `tokenizer` and `assistant_tokenizer`.

      Raises:
          ValueError: if a streamer is combined with beam search; if assisted generation is requested with
              `num_return_sequences > 1` or on a stateful model; if an encoder-decoder main model is paired with a
              non-encoder-decoder assistant whose encoder-related config attributes differ; or if the tokenizer
              arguments do not match whether the two models' text vocabulary sizes are equal.
      """

      def _is_provided(name):
          # Treat an explicit None like a missing key, as was already done for `assistant_model`.
          return generation_mode_kwargs.get(name) is not None

      if generation_mode == GenerationMode.BEAM_SEARCH and _is_provided("streamer"):
          raise ValueError(
              "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
          )

      if generation_mode == GenerationMode.ASSISTED_GENERATION:
          if generation_config.num_return_sequences > 1:
              raise ValueError(
                  "num_return_sequences has to be 1 when doing assisted generate, "
                  f"but is {generation_config.num_return_sequences}."
              )
          if self._is_stateful:
              # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
              # which is not possible with stateful models (they can't reset to a previous subset of generated text)
              raise ValueError(
                  f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
              )

      if (assistant_model := generation_mode_kwargs.get("assistant_model")) is not None:
          if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder:
              attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"]
              # Only compare the attributes that the assistant's config actually exposes.
              attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check]
              are_equal = all(
                  getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check
              )
              if not are_equal:
                  raise ValueError(
                      "The main model and the assistant don't have compatible encoder-dependent input shapes. "
                      "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
                  )

          doc_reference = (
              "(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)"
          )
          if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size:
              # Equal vocabulary sizes are treated as a shared tokenizer, so an assistant tokenizer is unexpected.
              if _is_provided("assistant_tokenizer"):
                  raise ValueError(
                      f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}."
                  )
          elif not (_is_provided("tokenizer") and _is_provided("assistant_tokenizer")):
              raise ValueError(
                  f"The main and assistant models have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}."
              )
  1419. def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
  1420. """Validates model kwargs for generation. Generate argument typos will also be caught here."""
  1421. # Excludes arguments that are handled before calling any model function
  1422. if self.config.is_encoder_decoder:
  1423. for key in ["decoder_input_ids"]:
  1424. model_kwargs.pop(key, None)
  1425. unused_model_args = []
  1426. model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
  1427. # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
  1428. # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
  1429. if "kwargs" in model_args or "model_kwargs" in model_args:
  1430. model_args |= set(inspect.signature(self.forward).parameters)
  1431. # Encoder-Decoder models may also need Encoder arguments from `model_kwargs`
  1432. if self.config.is_encoder_decoder:
  1433. base_model = getattr(self, self.base_model_prefix, None)
  1434. # allow encoder kwargs
  1435. encoder = getattr(self, "encoder", None)
  1436. # `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`.
  1437. # Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder`
  1438. # TODO: A better way to handle this.
  1439. if encoder is None and base_model is not None:
  1440. encoder = getattr(base_model, "encoder", None)
  1441. if encoder is not None:
  1442. encoder_model_args = set(inspect.signature(encoder.forward).parameters)
  1443. model_args |= encoder_model_args
  1444. # allow decoder kwargs
  1445. decoder = getattr(self, "decoder", None)
  1446. if decoder is None and base_model is not None:
  1447. decoder = getattr(base_model, "decoder", None)
  1448. if decoder is not None:
  1449. decoder_model_args = set(inspect.signature(decoder.forward).parameters)
  1450. model_args |= {f"decoder_{x}" for x in decoder_model_args}
  1451. # TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions'
  1452. for key, value in model_kwargs.items():
  1453. if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__:
  1454. unused_model_args.append(key)
  1455. if unused_model_args:
  1456. raise ValueError(
  1457. f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
  1458. " generate arguments will also show up in this list)"
  1459. )
  1460. def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
  1461. """Performs validation related to the resulting generated length"""
  1462. # 1. Max length warnings related to poor parameterization
  1463. if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
  1464. # 20 is the default max_length of the generation config
  1465. warnings.warn(
  1466. f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
  1467. "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
  1468. "generation.",
  1469. UserWarning,
  1470. )
  1471. if input_ids_length >= generation_config.max_length:
  1472. input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
  1473. raise ValueError(
  1474. f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
  1475. f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
  1476. " increasing `max_length` or, better yet, setting `max_new_tokens`."
  1477. )
  1478. # 2. Min length warnings due to unfeasible parameter combinations
  1479. min_length_error_suffix = (
  1480. " Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
  1481. "increase the maximum length."
  1482. )
  1483. if has_default_max_length:
  1484. min_length_error_suffix += (
  1485. f" Note that `max_length` is set to {generation_config.max_length}, its default value."
  1486. )
  1487. if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
  1488. warnings.warn(
  1489. f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"
  1490. f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
  1491. UserWarning,
  1492. )
  1493. if generation_config.min_new_tokens is not None:
  1494. min_length = generation_config.min_new_tokens + input_ids_length
  1495. if min_length > generation_config.max_length:
  1496. warnings.warn(
  1497. f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when "
  1498. f"added to the prompt length ({input_ids_length}), is larger than"
  1499. f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
  1500. UserWarning,
  1501. )
  1502. def _prepare_generated_length(
  1503. self,
  1504. generation_config,
  1505. has_default_max_length,
  1506. has_default_min_length,
  1507. model_input_name,
  1508. input_ids_length,
  1509. inputs_tensor,
  1510. ):
  1511. """Prepared max and min length in generation configs to avoid clashes between similar attributes"""
  1512. if generation_config.max_new_tokens is not None:
  1513. if not has_default_max_length and generation_config.max_length is not None:
  1514. logger.warning(
  1515. f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
  1516. f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
  1517. "Please refer to the documentation for more information. "
  1518. "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
  1519. )
  1520. generation_config.max_length = generation_config.max_new_tokens + input_ids_length
  1521. # if both `inputs_embeds` and `input_ids` are passed, we do not correct the length
  1522. # otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length``
  1523. elif (
  1524. model_input_name == "inputs_embeds"
  1525. and input_ids_length != inputs_tensor.shape[1]
  1526. and not self.config.is_encoder_decoder
  1527. ):
  1528. generation_config.max_length -= inputs_tensor.shape[1]
  1529. elif has_default_max_length: # by default let's always generate 20 new tokens
  1530. if generation_config.max_length == GenerationConfig().max_length:
  1531. generation_config.max_length = generation_config.max_length + input_ids_length
  1532. max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
  1533. if max_position_embeddings is not None:
  1534. generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
  1535. # same for min length
  1536. if generation_config.min_new_tokens is not None:
  1537. if not has_default_min_length:
  1538. logger.warning(
  1539. f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(="
  1540. f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. "
  1541. "Please refer to the documentation for more information. "
  1542. "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
  1543. )
  1544. generation_config.min_length = generation_config.min_new_tokens + input_ids_length
  1545. elif (
  1546. model_input_name == "inputs_embeds"
  1547. and input_ids_length != inputs_tensor.shape[1]
  1548. and not self.config.is_encoder_decoder
  1549. ):
  1550. generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
  1551. return generation_config
def _prepare_generation_config(
    self,
    generation_config: Optional[GenerationConfig],
    use_model_defaults: Optional[bool] = None,
    **kwargs: Any,
) -> tuple[GenerationConfig, dict]:
    """
    Prepares the base generation config, then applies any generation configuration options from kwargs. This
    function handles retrocompatibility with respect to configuration files.

    Args:
        generation_config (`GenerationConfig`, *optional*):
            Config passed by the caller. When `None`, the model's `self.generation_config` is used as the base
            (possibly regenerated from the model config first, see the legacy branch below).
        use_model_defaults (`bool`, *optional*):
            Only consulted when `generation_config` is passed.
            - `True`: every field the caller left equal to the global `GenerationConfig()` default is replaced by
              the model's value, when the model's value differs from that default.
            - `False`: only missing bos/eos/pad/decoder-start token ids are filled in from the model.
            - `None`: behaves like `True` if the model's generation config was saved with transformers >= 4.50.0
              (and then logs which values were changed), otherwise like `False`.
        **kwargs:
            Applied last through `generation_config.update(**kwargs)`; the dict that call returns becomes
            `model_kwargs` (presumably the entries that are not generation attributes — confirm against
            `GenerationConfig.update`).

    Returns:
        `tuple[GenerationConfig, dict]`: a deep copy of the resolved generation config, and `model_kwargs`, to
        which `output_attentions` / `output_hidden_states` are added when they are truthy in the config.
    """
    # parameterization priority:
    # kwargs > non-global default values in `generation_config` > `model.generation_config` > GenerationConfig()
    # TODO (joao): per-model generation config classes.
    using_model_generation_config = False
    if generation_config is None:
        # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
        # the following conditions must be met
        # 1) the generation config must have been created from the model config (`_from_model_config` field);
        # 2) the generation config must have seen no modification since its creation (the hash is the same);
        # 3) there are non-default generation parameters in the model config.
        # 4) the user must have set new generation parameters in the model config.
        if (
            self.generation_config._from_model_config  # 1)
            and self.generation_config._original_object_hash == hash(self.generation_config)  # 2)
            and len(self.config._get_non_default_generation_parameters()) > 0  # 3)
        ):
            new_generation_config = GenerationConfig.from_model_config(self.config)
            if new_generation_config != self.generation_config:  # 4)
                warnings.warn(
                    "You have modified the pretrained model configuration to control generation. This is a"
                    " deprecated strategy to control generation and will be removed in v5."
                    " Please use and modify the model generation configuration (see"
                    " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
                    UserWarning,
                )
                # NOTE: this replaces the model's stored generation config, not only the one used for this call
                self.generation_config = new_generation_config
        generation_config = self.generation_config
        using_model_generation_config = True
        # Related to #40039: prior to this PR, models with sliding window attention were forced to have
        # `cache_implementation="hybrid"` (the static sliding window cache). For these models, we now want to use
        # the dynamic sliding window cache by default, so we UNSET `cache_implementation` if it is a default value.
        # (if we're inside this branch, then it is because we're using default values from the Hub)
        # NOTE(review): `generation_config` is still the same object as `self.generation_config` here (the deepcopy
        # happens below), so this also clears `cache_implementation` on the model's stored generation config —
        # confirm this persistence is intended.
        if generation_config.cache_implementation == "hybrid":
            generation_config.cache_implementation = None
    # `torch.export.export` usually raises an exception if it is called
    # with ``strict=True``. deepcopy can only be processed if ``strict=False``.
    # From here on only the private copy is written to (attribute overrides below and `update(**kwargs)`).
    generation_config = copy.deepcopy(generation_config)
    if not using_model_generation_config:
        # If `generation_config` is provided:
        # - `use_model_defaults`: let's fallback ALL default values to the model's generation config
        # - otherwise: legacy behavior, let's just make sure we have the tokens defined
        # `base_version` drops pre/dev/post-release suffixes (e.g. "4.50.0.dev0" -> "4.50.0") before comparing
        model_base_version = version.parse(version.parse(self.generation_config.transformers_version).base_version)
        if use_model_defaults is True or (
            use_model_defaults is None and model_base_version >= version.parse("4.50.0")
        ):
            modified_values = {}
            global_default_generation_config = GenerationConfig()
            model_generation_config = self.generation_config
            # we iterate over the model's generation config: it may hold custom keys, which we'll want to copy
            for key, model_gen_config_value in model_generation_config.__dict__.items():
                if key.startswith("_") or key == "transformers_version":  # metadata
                    continue
                # Don't set `cache_implementation = 'hybrid'` from the model defaults, see #40135
                if key == "cache_implementation" and model_generation_config.cache_implementation == "hybrid":
                    continue
                global_default_value = getattr(global_default_generation_config, key, None)
                custom_gen_config_value = getattr(generation_config, key, None)
                # Overwrite only fields still equal to the global default, and only when the model has its own
                # non-default value. A value the caller explicitly set to the global default cannot be told apart
                # from an unset one, so it is overwritten as well (hence the temperature edge case below).
                if (
                    custom_gen_config_value == global_default_value
                    and model_gen_config_value != global_default_value
                ):
                    modified_values[key] = model_gen_config_value
                    setattr(generation_config, key, model_gen_config_value)
            # edge case: we may set `temperature=0.0` and `do_sample=False`, but the model defaults to
            # `do_sample=True`
            if generation_config.temperature == 0.0:
                generation_config.do_sample = False
            # Only announce the overrides when the caller did not explicitly opt in
            if use_model_defaults is None and len(modified_values) > 0:
                logger.warning_once(
                    f"`generation_config` default values have been modified to match model-specific defaults: "
                    f"{modified_values}. If this is not desired, please set these values explicitly."
                )
        else:
            # Legacy behavior: only fill in special token ids the caller left unset
            if generation_config.bos_token_id is None:
                generation_config.bos_token_id = self.generation_config.bos_token_id
            if generation_config.eos_token_id is None:
                generation_config.eos_token_id = self.generation_config.eos_token_id
            if generation_config.pad_token_id is None:
                generation_config.pad_token_id = self.generation_config.pad_token_id
            if generation_config.decoder_start_token_id is None:
                generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
    # Finally, apply any passed kwargs
    model_kwargs = generation_config.update(**kwargs)
    # And keep in model_kwargs variable output controls
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    model_kwargs.update({"output_attentions": output_attentions} if output_attentions else {})
    model_kwargs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
    return generation_config, model_kwargs
  1651. def _get_initial_cache_position(self, seq_length, device, model_kwargs):
  1652. """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
  1653. # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
  1654. if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
  1655. return model_kwargs
  1656. if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
  1657. cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
  1658. elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:
  1659. cache_position = (
  1660. torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
  1661. )
  1662. else:
  1663. cache_position = torch.ones(seq_length, dtype=torch.int64, device=device).cumsum(0) - 1
  1664. past_length = 0
  1665. if model_kwargs.get("past_key_values") is not None:
  1666. cache = model_kwargs["past_key_values"]
  1667. past_length = 0
  1668. # Support for BC tuple cache format
  1669. if isinstance(cache, tuple):
  1670. past_length = cache[0][0].shape[2]
  1671. elif hasattr(cache, "get_seq_length"):
  1672. past_length = cache.get_seq_length()
  1673. cache_position = cache_position[past_length:]
  1674. model_kwargs["cache_position"] = cache_position
  1675. return model_kwargs
  1676. def _get_cache(self, cache_implementation: str, batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
  1677. """
  1678. Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
  1679. new `generate` call requires a larger cache or uses a different batch size.
  1680. Returns the resulting cache object.
  1681. """
  1682. requires_cross_attention_cache = (
  1683. self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
  1684. )
  1685. offload_cache = "offloaded" in cache_implementation
  1686. if hasattr(self, "_cache"):
  1687. cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache
  1688. need_new_cache = (
  1689. not hasattr(self, "_cache")
  1690. or cache_to_check.offloading != offload_cache
  1691. or cache_to_check.max_batch_size != batch_size
  1692. or cache_to_check.max_cache_len < max_cache_len
  1693. )
  1694. if requires_cross_attention_cache and hasattr(self, "_cache"):
  1695. need_new_cache = (
  1696. need_new_cache
  1697. or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
  1698. )
  1699. if need_new_cache:
  1700. self_attention_cache_kwargs = {
  1701. "config": self.config.get_text_config(decoder=True),
  1702. "max_cache_len": max_cache_len,
  1703. "offloading": offload_cache,
  1704. }
  1705. self._cache = StaticCache(**self_attention_cache_kwargs)
  1706. if requires_cross_attention_cache:
  1707. cross_attention_cache_kwargs = {
  1708. "config": self.config.get_text_config(decoder=True),
  1709. "max_cache_len": model_kwargs["encoder_outputs"][0].shape[1],
  1710. "offloading": offload_cache,
  1711. }
  1712. self._cache = EncoderDecoderCache(self._cache, StaticCache(**cross_attention_cache_kwargs))
  1713. else:
  1714. self._cache.reset()
  1715. return self._cache
  1716. @classmethod
  1717. def _supports_default_dynamic_cache(cls) -> bool:
  1718. """
  1719. Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
  1720. This adds exception for some models like `Mamba` models which use their own caches
  1721. and do not need to initialize the Cache in advance in order to save memory (because no back and forth
  1722. `to_legacy_cache` and `from_legacy_cache` will be performed for mamba-based models).
  1723. """
  1724. # NOTE: remove xlnet/reformer when the models are deprecated, non-standard model architecture/cache name
  1725. return not cls._is_stateful and all(
  1726. special_model_name not in cls.__name__.lower()
  1727. for special_model_name in [
  1728. "reformer",
  1729. "minimax",
  1730. "xlnet",
  1731. "lfm2",
  1732. "lfm2-vl",
  1733. ]
  1734. )
  1735. def _prepare_cache_for_generation(
  1736. self,
  1737. generation_config: GenerationConfig,
  1738. model_kwargs: dict,
  1739. generation_mode: GenerationMode,
  1740. batch_size: int,
  1741. max_cache_length: int,
  1742. ) -> bool:
  1743. """
  1744. Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
  1745. instantiated, writes it to `model_kwargs`, under the name expected by the model.
  1746. """
  1747. is_hybrid_cache = any(class_name in self.__class__.__name__.lower() for class_name in ["mamba", "falconh1"])
  1748. cache_name = "past_key_values" if not is_hybrid_cache else "cache_params"
  1749. requires_cross_attention_cache = (
  1750. self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
  1751. )
  1752. # Quick escape route 1: if the user specifies a cache, we only need to:
  1753. # a) check for conflicting `generate` arguments
  1754. # b) convert to the new cache format (if the user passes a legacy cache and model supports it)
  1755. user_defined_cache = model_kwargs.get(cache_name)
  1756. if user_defined_cache is not None:
  1757. if generation_config.cache_implementation is not None:
  1758. raise ValueError(
  1759. f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
  1760. "Cache object) is unsupported. Please use only one of the two."
  1761. )
  1762. if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache():
  1763. logger.warning_once(
  1764. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  1765. "You should pass an instance of `Cache` instead."
  1766. )
  1767. model_kwargs[cache_name] = (
  1768. DynamicCache.from_legacy_cache(user_defined_cache)
  1769. if not requires_cross_attention_cache
  1770. else EncoderDecoderCache.from_legacy_cache(user_defined_cache)
  1771. )
  1772. return
  1773. # Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in
  1774. # `generation_config.validate()`)
  1775. if generation_config.use_cache is False:
  1776. return
  1777. # Quick escape route 3: model that only supports legacy caches or models that supply it in
  1778. # `prepare_inputs_for_generation` (mamba, zamba, ...)
  1779. if not self._supports_default_dynamic_cache():
  1780. if generation_config.cache_implementation is not None:
  1781. logger.warning_once(
  1782. "This model does not support `Cache` instances. `cache_implementation` (set to "
  1783. f"{generation_config.cache_implementation}) will be ignored.",
  1784. )
  1785. return
  1786. # Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation`
  1787. # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
  1788. # which is only supported in dynamic caches atm
  1789. if (
  1790. generation_mode == GenerationMode.ASSISTED_GENERATION
  1791. and generation_config.cache_implementation is not None
  1792. ):
  1793. logger.warning_once(
  1794. "An assistant model is provided, using a dynamic cache instead of a cache of type="
  1795. f"'{generation_config.cache_implementation}'."
  1796. )
  1797. generation_config.cache_implementation = None
  1798. # Assisted decoding and contrastive search require cache rollback, which is incompatible with sliding layers.
  1799. # To handle this, we skip passing the model config to DynamicCache (forcing a full-layer cache).
  1800. # The "dynamic_full" option is a shortcut for generate() users to avoid sliding layers on their own.
  1801. if (
  1802. generation_mode in (GenerationMode.ASSISTED_GENERATION, GenerationMode.CONTRASTIVE_SEARCH)
  1803. or generation_config.cache_implementation == "dynamic_full"
  1804. ):
  1805. dynamic_cache_kwargs = {}
  1806. else:
  1807. dynamic_cache_kwargs = {"config": self.config.get_text_config(decoder=True)}
  1808. if generation_config.cache_implementation is not None:
  1809. if generation_config.cache_implementation in ALL_STATIC_CACHE_IMPLEMENTATIONS:
  1810. if generation_config.cache_implementation in DEPRECATED_STATIC_CACHE_IMPLEMENTATIONS:
  1811. logger.warning_once(
  1812. f"Using `cache_implementation='{generation_config.cache_implementation}' is deprecated. "
  1813. f"Please only use one of {STATIC_CACHE_IMPLEMENTATIONS}, and the layer structure will be "
  1814. "inferred automatically."
  1815. )
  1816. model_kwargs[cache_name] = self._get_cache(
  1817. cache_implementation=generation_config.cache_implementation,
  1818. batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
  1819. max_cache_len=max_cache_length,
  1820. model_kwargs=model_kwargs,
  1821. )
  1822. elif generation_config.cache_implementation == "quantized":
  1823. if self.config.is_encoder_decoder or not self._supports_default_dynamic_cache():
  1824. raise ValueError(
  1825. "This model does not support the quantized cache. If you want your model to support quantized "
  1826. "cache, please open an issue and tag @zucchini-nlp."
  1827. )
  1828. cache_config = generation_config.cache_config if generation_config.cache_config is not None else {}
  1829. # Add the config if it was not provided, as it's a required argument
  1830. if "config" not in cache_config:
  1831. cache_config["config"] = self.config.get_text_config()
  1832. # Pop the backend from the config (defaults to quanto if not defined)
  1833. backend = cache_config.pop("backend", "quanto")
  1834. if backend == "quanto" and not is_optimum_quanto_available():
  1835. raise ImportError(
  1836. "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto "
  1837. "backend. Please install it via with `pip install optimum-quanto`"
  1838. )
  1839. elif backend == "HQQ" and not is_hqq_available():
  1840. raise ImportError(
  1841. "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
  1842. "Please install it via with `pip install hqq`"
  1843. )
  1844. model_kwargs[cache_name] = QuantizedCache(backend=backend, **cache_config)
  1845. elif generation_config.cache_implementation == "offloaded":
  1846. model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs, offloading=True)
  1847. elif "dynamic" in generation_config.cache_implementation:
  1848. model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs)
  1849. # Use DynamicCache instance by default. This will avoid back and forth from legacy format that
  1850. # keeps copying the cache thus using much more memory
  1851. # TODO (joao): remove this `else` when we remove the last traces of the legacy cache format (v4.58.0, search
  1852. # for `instance(past_key_values, Cache)` as well). In general, if `cache_implementation` is unset, cache
  1853. # initialization should happen inside the model at prefill time.
  1854. else:
  1855. model_kwargs[cache_name] = DynamicCache(**dynamic_cache_kwargs)
  1856. # TODO (joao): this logic is incomplete, e.g. `offloaded` should apply to both caches. Refactor this function
  1857. # to correctly pass parameterization to both caches.
  1858. if requires_cross_attention_cache and not isinstance(model_kwargs[cache_name], EncoderDecoderCache):
  1859. model_kwargs[cache_name] = EncoderDecoderCache(
  1860. model_kwargs[cache_name], # self-attention cache
  1861. DynamicCache(**dynamic_cache_kwargs), # cross-attention cache
  1862. )
  1863. def _supports_logits_to_keep(self) -> bool:
  1864. """
  1865. Return True if the current model supports the keyword argument `logits_to_keep` in forward()
  1866. to save memory. Checking it in this way allows to avoid using a new model attribute.
  1867. """
  1868. return "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())
  1869. def _prepare_special_tokens(
  1870. self,
  1871. generation_config: GenerationConfig,
  1872. kwargs_has_attention_mask: Optional[bool] = None,
  1873. device: Optional[Union[torch.device, str]] = None,
  1874. ):
  1875. """
  1876. Prepares the special tokens for generation, overwriting the generation config with their processed versions
  1877. converted to tensor.
  1878. Note that `generation_config` is changed in place and stops being serializable after this method is called.
  1879. That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
  1880. function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
  1881. """
  1882. # Convert special tokens to tensors
  1883. def _tensor_or_none(token, device=None):
  1884. if token is None:
  1885. return token
  1886. device = device if device is not None else self.device
  1887. if isinstance(token, torch.Tensor):
  1888. return token.to(device)
  1889. return torch.tensor(token, device=device, dtype=torch.long)
  1890. bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
  1891. eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
  1892. pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
  1893. decoder_start_token_tensor = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
  1894. # for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892)
  1895. if self.config.is_encoder_decoder:
  1896. decoder_start_token_tensor = (
  1897. decoder_start_token_tensor if decoder_start_token_tensor is not None else bos_token_tensor
  1898. )
  1899. # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
  1900. if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
  1901. eos_token_tensor = eos_token_tensor.unsqueeze(0)
  1902. # Set pad token if unset (and there are conditions to do so)
  1903. if pad_token_tensor is None and eos_token_tensor is not None:
  1904. if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
  1905. logger.warning(
  1906. "The attention mask and the pad token id were not set. As a consequence, you may observe "
  1907. "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
  1908. )
  1909. pad_token_tensor = eos_token_tensor[0]
  1910. logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
  1911. # Sanity checks/warnings
  1912. if self.config.is_encoder_decoder and decoder_start_token_tensor is None:
  1913. raise ValueError(
  1914. "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
  1915. )
  1916. if (
  1917. eos_token_tensor is not None
  1918. and isin_mps_friendly(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
  1919. ):
  1920. if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
  1921. logger.warning_once(
  1922. "The attention mask is not set and cannot be inferred from input because pad token is same as "
  1923. "eos token. As a consequence, you may observe unexpected behavior. Please pass your input's "
  1924. "`attention_mask` to obtain reliable results."
  1925. )
  1926. if eos_token_tensor is not None and (
  1927. torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any()
  1928. ):
  1929. logger.warning(
  1930. f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation "
  1931. "will not stop until the maximum length is reached. Depending on other flags, it may even crash."
  1932. )
  1933. # Update generation config with the updated special tokens tensors
  1934. # NOTE: this must be written into a different attribute name than the one holding the original special tokens
  1935. # (in their non-tensor form), in order to enable end-to-end compilation. See
  1936. # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
  1937. generation_config._bos_token_tensor = bos_token_tensor
  1938. generation_config._eos_token_tensor = eos_token_tensor
  1939. generation_config._pad_token_tensor = pad_token_tensor
  1940. generation_config._decoder_start_token_tensor = decoder_start_token_tensor
  1941. def _valid_auto_compile_criteria(self, model_kwargs: dict[str, Any], generation_config: GenerationConfig) -> bool:
  1942. """
  1943. Determines whether to trigger auto-compilation of the model's forward pass at generation time.
  1944. """
  1945. # Override: honor `disable_compile` flag
  1946. if generation_config.disable_compile:
  1947. return False
  1948. # Base logic
  1949. valid_hardware = self.device.type == "cuda" or bool(
  1950. generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices
  1951. )
  1952. using_compilable_cache = (
  1953. isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable
  1954. )
  1955. can_compile = valid_hardware and using_compilable_cache
  1956. # Exception 1: Some quantization methods do not support compilation
  1957. if getattr(self, "hf_quantizer", None) is not None:
  1958. can_compile &= self.hf_quantizer.is_compileable
  1959. if hasattr(self, "hf_device_map"):
  1960. all_model_devices = set(self.hf_device_map.values())
  1961. # Exception 2: Don't compile if the model is using CPU offload (as of April 2025, this results in a crash)
  1962. has_cpu_offload = "cpu" in all_model_devices and len(all_model_devices) > 1
  1963. can_compile &= not has_cpu_offload
  1964. # Exception 3: Disk offload is not supported for compilation
  1965. has_disk_offload = "disk" in all_model_devices
  1966. can_compile &= not has_disk_offload
  1967. # Finally: if the user has manually specified compilation options, but compilation is not possible, let's warn
  1968. # them
  1969. if generation_config.compile_config is not None and not can_compile:
  1970. logger.warning_once(
  1971. "You have set `compile_config`, but we are unable to meet the criteria for compilation. Compilation "
  1972. "will be skipped."
  1973. )
  1974. return can_compile
  1975. def _get_deprecated_gen_repo(
  1976. self,
  1977. generation_mode: GenerationMode,
  1978. trust_remote_code: bool,
  1979. custom_generate: Optional[str] = None,
  1980. ) -> Optional[str]:
  1981. """
  1982. Returns the Hub repo for a deprecated generation mode, if any.
  1983. """
  1984. if custom_generate is not None or "/" not in (repo := GENERATION_MODES_MAPPING[generation_mode]):
  1985. return None
  1986. logger.warning_once(
  1987. f"{generation_mode.name.replace('_', ' ').title()} was moved to a `custom_generate` repo: https://hf.co/{repo}. "
  1988. f"To prevent loss of backward compatibility, add `custom_generate='{repo}'` "
  1989. "to your `generate` call before v4.62.0."
  1990. )
  1991. if not trust_remote_code:
  1992. raise ValueError(
  1993. f"{generation_mode.name.replace('_', ' ').title()} requires `trust_remote_code=True` in your `generate` call, "
  1994. f"since it loads https://hf.co/{repo}."
  1995. )
  1996. return repo
  1997. def _extract_generation_mode_kwargs(
  1998. self,
  1999. custom_generate,
  2000. kwargs,
  2001. synced_gpus,
  2002. assistant_model,
  2003. streamer,
  2004. ) -> dict[str, Any]:
  2005. """
  2006. Extracts and returns the generation mode related keyword arguments from the provided kwargs.
  2007. """
  2008. generation_mode_kwargs = {
  2009. "tokenizer": kwargs.pop("tokenizer", None),
  2010. "assistant_tokenizer": kwargs.pop("assistant_tokenizer", None),
  2011. "assistant_model": assistant_model,
  2012. "streamer": streamer,
  2013. }
  2014. generation_mode_kwargs["synced_gpus"] = (
  2015. (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
  2016. if synced_gpus is None
  2017. else synced_gpus
  2018. )
  2019. generation_mode_kwargs = {k: v for k, v in generation_mode_kwargs.items() if v is not None}
  2020. # Custom_generate callables can have their own set of arguments
  2021. # To extract them, we compare the signature with the standard _sample method
  2022. if isinstance(custom_generate, Callable):
  2023. usual_mode_kwargs = inspect.signature(GenerationMixin._sample).parameters.keys()
  2024. custom_generate_kwargs = inspect.signature(custom_generate).parameters.keys()
  2025. new_custom_keys = custom_generate_kwargs - usual_mode_kwargs
  2026. generation_mode_kwargs = {k: kwargs.pop(k) for k in new_custom_keys if k in kwargs}
  2027. return generation_mode_kwargs
  @torch.no_grad()
  def generate(
      self,
      inputs: Optional[torch.Tensor] = None,
      generation_config: Optional[GenerationConfig] = None,
      logits_processor: Optional[LogitsProcessorList] = None,
      stopping_criteria: Optional[StoppingCriteriaList] = None,
      prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
      synced_gpus: Optional[bool] = None,
      assistant_model: Optional["PreTrainedModel"] = None,
      streamer: Optional["BaseStreamer"] = None,
      negative_prompt_ids: Optional[torch.Tensor] = None,
      negative_prompt_attention_mask: Optional[torch.Tensor] = None,
      use_model_defaults: Optional[bool] = None,
      custom_generate: Optional[Union[str, Callable]] = None,
      **kwargs,
  ) -> Union[GenerateOutput, torch.LongTensor]:
      r"""
      Generates sequences of token ids for models with a language modeling head.

      <Tip warning={true}>

      Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
      model's default generation configuration. You can override any `generation_config` by passing the corresponding
      parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.

      For an overview of generation strategies and code examples, check out the [following
      guide](../generation_strategies).

      </Tip>

      Parameters:
          inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
              The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
              method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
              should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
              `input_ids`, `input_values`, `input_features`, or `pixel_values`.
          generation_config ([`~generation.GenerationConfig`], *optional*):
              The generation configuration to be used as base parametrization for the generation call. `**kwargs`
              passed to generate matching the attributes of `generation_config` will override them. If
              `generation_config` is not provided, the default will be used, which has the following loading
              priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
              configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
              default values, whose documentation should be checked to parameterize generation.
          logits_processor (`LogitsProcessorList`, *optional*):
              Custom logits processors that complement the default logits processors built from arguments and
              generation config. If a logit processor is passed that is already created with the arguments or a
              generation config an error is thrown. This feature is intended for advanced users.
          stopping_criteria (`StoppingCriteriaList`, *optional*):
              Custom stopping criteria that complements the default stopping criteria built from arguments and a
              generation config. If a stopping criteria is passed that is already created with the arguments or a
              generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
              sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
              intended for advanced users.
          prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
              If provided, this function constraints the beam search to allowed tokens only at each step. If not
              provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
              `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
              on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
              for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
              Retrieval](https://huggingface.co/papers/2010.00904).
          synced_gpus (`bool`, *optional*):
              Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
              to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
              deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
          assistant_model (`PreTrainedModel`, *optional*):
              An assistant model that can be used to accelerate generation. The assistant model must have the exact
              same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
              is much faster than running generation with the model you're calling generate from. As such, the
              assistant model should be much smaller.
          streamer (`BaseStreamer`, *optional*):
              Streamer object that will be used to stream the generated sequences. Generated tokens are passed
              through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
          negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              The negative prompt needed for some processors such as CFG. The batch size must match the input batch
              size. This is an experimental feature, subject to breaking API changes in future versions.
          negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Attention_mask for `negative_prompt_ids`.
          use_model_defaults (`bool`, *optional*):
              When it is `True`, unset parameters in `generation_config` will be set to the model-specific default
              generation configuration (`model.generation_config`), as opposed to the global defaults
              (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
              `True`.
          custom_generate (`str` or `Callable`, *optional*):
              One of the following:
                  - `str` (Hugging Face Hub repository name): runs the custom `generate` function defined at
                    `custom_generate/generate.py` in that repository instead of the standard `generate` method. The
                    repository fully replaces the generation logic, and the return type may differ.
                  - `str` (local repository path): same as above but from a local path, `trust_remote_code` not
                    required.
                  - `Callable`: `generate` will perform the usual input preparation steps, then call the provided
                    callable to run the decoding loop.
              For more information, see [the docs](../../generation_strategies#custom-generation-methods).
          kwargs (`dict[str, Any]`, *optional*):
              Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
              forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
              specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
              `trust_remote_code` is also read from here (it is only used to load `custom_generate` recipes).

      Return:
          [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
          or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.

              If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
              [`~utils.ModelOutput`] types are:

                  - [`~generation.GenerateDecoderOnlyOutput`],
                  - [`~generation.GenerateBeamDecoderOnlyOutput`]

              If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
              [`~utils.ModelOutput`] types are:

                  - [`~generation.GenerateEncoderDecoderOutput`],
                  - [`~generation.GenerateBeamEncoderDecoderOutput`]

          When `custom_generate` is a string, the return value is whatever the loaded function returns. When
          `generation_config.return_legacy_cache` is `True` and the output's `past_key_values` object exposes
          `to_legacy_cache`, that field is converted to the legacy cache format before returning.
      """
      # 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
      # `trust_remote_code` is popped so it is not forwarded below as a generation/model kwarg.
      trust_remote_code = kwargs.pop("trust_remote_code", None)
      if custom_generate is not None and isinstance(custom_generate, str):
          # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
          # they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to
          # trigger the custom generation. They can access to methods from `GenerationMixin` through `model`.
          # NOTE: `locals()` is captured below, so every local bound before this point (other than the excluded
          # names) is forwarded to the custom function -- do not introduce new locals above this line.
          global_keys_to_exclude = {
              "self",
              "kwargs",
              "global_keys_to_exclude",
              "trust_remote_code",
              "custom_generate",
          }
          generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude}
          generate_arguments.update(kwargs)

          custom_generate_function = self.load_custom_generate(
              custom_generate, trust_remote_code=trust_remote_code, **kwargs
          )
          return custom_generate_function(model=self, **generate_arguments)

      # 1. Handle kwargs, `generation_config`, validate them and obtain generation mode
      # For a callable `custom_generate`, this pops the callable's extra arguments out of `kwargs` in place.
      generation_mode_kwargs = self._extract_generation_mode_kwargs(
          custom_generate,
          kwargs,
          synced_gpus,
          assistant_model,
          streamer,
      )

      generation_config, model_kwargs = self._prepare_generation_config(
          generation_config, use_model_defaults, **kwargs
      )
      generation_mode = generation_config.get_generation_mode(assistant_model)
      if isinstance(custom_generate, Callable):
          decoding_method = custom_generate
      else:
          # type() required to access the unbound class-level method
          decoding_method = getattr(type(self), GENERATION_MODES_MAPPING[generation_mode])

      self._validate_model_kwargs(model_kwargs.copy())
      self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)

      # Deprecation-related step: set Hub repo for deprecated strategies.
      # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode.
      # It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps.
      # TODO joao, manuel: remove this in v4.62.0
      if deprecated_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate):
          # Re-enter `generate` with the deprecated mode's Hub repo as `custom_generate` (handled by step 0).
          return GenerationMixin.generate(
              self,
              inputs=inputs,
              generation_config=generation_config,
              logits_processor=logits_processor,
              stopping_criteria=stopping_criteria,
              prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
              assistant_model=assistant_model,
              negative_prompt_ids=negative_prompt_ids,
              negative_prompt_attention_mask=negative_prompt_attention_mask,
              use_model_defaults=use_model_defaults,
              custom_generate=deprecated_mode_repo,
              trust_remote_code=trust_remote_code,
              **generation_mode_kwargs,
              **kwargs,
          )

      # 2. Set generation parameters if not already defined
      logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
      stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

      accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
      # an attention mask is only built below when encoder outputs were not precomputed by the caller
      requires_attention_mask = "encoder_outputs" not in model_kwargs
      kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

      # 3. Define model inputs
      inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
          inputs, generation_config.bos_token_id, model_kwargs
      )
      # Some generation modes (e.g. assisted) need `inputs_tensor` to rerun encoder.forward()
      if "inputs_tensor" in inspect.signature(decoding_method).parameters.keys():
          generation_mode_kwargs["inputs_tensor"] = inputs_tensor
      batch_size = inputs_tensor.shape[0]

      device = inputs_tensor.device
      self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

      # decoder-only models must use left-padding for batched generation.
      if not self.config.is_encoder_decoder:
          # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
          # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
          if (
              generation_config._pad_token_tensor is not None
              and batch_size > 1
              and len(inputs_tensor.shape) == 2
              and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
          ):
              logger.warning(
                  "A decoder-only architecture is being used, but right-padding was detected! For correct "
                  "generation results, please set `padding_side='left'` when initializing the tokenizer."
              )

      # 4. Define other model kwargs
      # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
      # generating the first new token or not, and we only want to use the embeddings for the first new token)
      if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
          generation_config.use_cache = True

      if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
          model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
              inputs_tensor, generation_config, model_kwargs
          )
      elif kwargs_has_attention_mask:
          # TODO (joao): generalize this check with other types of inputs
          if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
              raise ValueError("`attention_mask` passed to `generate` must be 2D.")

      if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
          # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
          model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
              inputs_tensor, model_kwargs, model_input_name, generation_config
          )

      # 5. Prepare `input_ids` which will be used for auto-regressive generation
      if self.config.is_encoder_decoder:
          input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
              batch_size=batch_size,
              model_input_name=model_input_name,
              model_kwargs=model_kwargs,
              decoder_start_token_id=generation_config._decoder_start_token_tensor,
              device=inputs_tensor.device,
          )
      else:
          input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

      # Expand inputs depending on the generation mode
      input_ids, model_kwargs = self._expand_inputs_for_generation(
          input_ids=input_ids,
          expand_size=max(generation_config.num_beams, generation_config.num_return_sequences),
          is_encoder_decoder=self.config.is_encoder_decoder,
          **model_kwargs,
      )

      if generation_config.token_healing:
          input_ids = self.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer"))

      if streamer is not None:
          streamer.put(input_ids.cpu())

      # 6. Prepare `max_length` depending on other stopping criteria.
      input_ids_length = input_ids.shape[1]
      has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
      has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
      generation_config = self._prepare_generated_length(
          generation_config=generation_config,
          has_default_max_length=has_default_max_length,
          has_default_min_length=has_default_min_length,
          model_input_name=model_input_name,
          inputs_tensor=inputs_tensor,
          input_ids_length=input_ids_length,
      )

      # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
      # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
      # dynamically overrides this value as it can need more than the last token logits
      if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
          model_kwargs["logits_to_keep"] = 1

      self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

      # 7. Prepare the cache.
      # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
      # - different models have a different cache name expected by the model (default = "past_key_values")
      # - `max_length`, prepared above, is used to determine the maximum cache length
      max_cache_length = generation_config.max_length - 1
      if (
          inputs_tensor.shape[1] != input_ids_length
          and model_input_name == "inputs_embeds"
          and not self.config.is_encoder_decoder
      ):
          # decoder-only with `inputs_embeds`: the embedded prompt also occupies cache slots
          max_cache_length += inputs_tensor.shape[1]
      self._prepare_cache_for_generation(
          generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
      )

      if self.device.type != input_ids.device.type:
          warnings.warn(
              "You are calling .generate() with the `input_ids` being on a device type different"
              f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
              f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
              " Please make sure that you have put `input_ids` to the"
              f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
              " running `.generate()`.",
              UserWarning,
          )

      # 8. prepare logits processors and stopping criteria
      prepared_logits_processor = self._get_logits_processor(
          generation_config=generation_config,
          input_ids_seq_length=input_ids_length,
          encoder_input_ids=inputs_tensor,
          prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
          logits_processor=logits_processor,
          device=inputs_tensor.device,
          model_kwargs=model_kwargs,
          negative_prompt_ids=negative_prompt_ids,
          negative_prompt_attention_mask=negative_prompt_attention_mask,
      )
      prepared_stopping_criteria = self._get_stopping_criteria(
          generation_config=generation_config,
          stopping_criteria=stopping_criteria,
          tokenizer=generation_mode_kwargs.get("tokenizer"),
      )

      # Set model_kwargs `use_cache` so we can use it later in forward runs
      model_kwargs["use_cache"] = generation_config.use_cache

      # 9. Call generation mode
      # `decoding_method` is an unbound function (class attribute or user callable), so `self` is passed explicitly.
      result = decoding_method(
          self,
          input_ids,
          logits_processor=prepared_logits_processor,
          stopping_criteria=prepared_stopping_criteria,
          generation_config=generation_config,
          **generation_mode_kwargs,
          **model_kwargs,
      )

      # Convert to legacy cache format if requested.
      # `getattr` needs a default here: `past_key_values` may be `None` or an object without `to_legacy_cache`,
      # and a two-argument `getattr` would raise `AttributeError` instead of skipping the conversion.
      if (
          generation_config.return_legacy_cache is True
          and hasattr(result, "past_key_values")
          and getattr(result.past_key_values, "to_legacy_cache", None) is not None
      ):
          result.past_key_values = result.past_key_values.to_legacy_cache()
      return result
  2339. def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
  2340. """
  2341. Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
  2342. fed through `this_peer_finished`. ZeRO stage 3-friendly.
  2343. """
  2344. if synced_gpus:
  2345. # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
  2346. # The following logic allows an early break if all peers finished generating their sequence
  2347. this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device)
  2348. # send 0.0 if we finished, 1.0 otherwise
  2349. dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
  2350. # did all peers finish? the reduced sum will be 0.0 then
  2351. if this_peer_finished_flag.item() == 0.0:
  2352. return False
  2353. elif this_peer_finished:
  2354. return False
  2355. return True
  def heal_tokens(
      self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
  ) -> torch.LongTensor:
      r"""
      Heals the tail token of each prompt in `input_ids` (token healing).

      Each prompt is decoded, stripped of leading/trailing whitespace and re-tokenized. For every non-empty row, the
      last token is removed and regenerated (`max_new_tokens=1`) with a sequence bias that favors vocabulary entries
      extending the removed token's text (with the original token favored slightly more), e.g. `http` -> `https`.

      Note that the returned tensor comes from re-tokenizing the stripped prompts, so its length can differ from
      the input's.

      Parameters:
          input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation.
          tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids.

      Return:
          `torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension.

      Raises:
          `ValueError`: if `tokenizer` is not provided.
      """
      if tokenizer is None:
          raise ValueError(
              " When generating with token healing, you must pass the model's tokenizer to the `tokenizer` "
              "argument of `generate`."
          )

      bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id
      vocab_trie = ExtensionsTrie(tokenizer.get_vocab())
      generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id)

      # assumption: leading/trailing whitespace is not meaningful, so the prompts are
      # stripped before re-tokenizing to desensitize generation to whitespace artefacts
      prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
      input_ids = tokenizer(
          prompts,
          return_tensors="pt",
          padding=True,
      ).input_ids.to(input_ids.device)

      # replace bos with pad to not condition healing on it
      input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)

      # the code below indexes the last column, so an empty tensor is returned as-is
      if input_ids.numel() == 0:
          return input_ids

      tail_ids = input_ids[:, -1].tolist()

      # tail tokens are used for a prefix search, thus, whitespaces are replaced with
      # their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
      if tokenizer.convert_tokens_to_ids(" ") is not None:
          space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
          tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
      else:
          tail_toks = (tokenizer.decode(t) for t in tail_ids)

      for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):
          batch_ids = input_ids[batch_idx]
          if torch.all(batch_ids == pad_token_id).item():
              continue  # skip empty sequences (all pad ids)

          # apply bias for alternatives (extensions) to the tail token;
          # `sequence_bias` keys must be tuples of token ids, hence the token-string -> id conversion
          seq_bias = {
              (tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok)
          }
          # Skip when there is nothing to heal with: no candidate at all, a single candidate, or the original tail
          # token is not among the candidates (presumably its decoded text differs from its vocabulary spelling).
          # Without these checks, the `+=` below raises `KeyError` whenever `(tail_id,)` is absent.
          if len(seq_bias) <= 1 or (tail_id,) not in seq_bias:
              continue

          # slightly favor original token to limit aggressive healing e.g. 'http' -> 'https'
          seq_bias[(tail_id,)] += 1.0
          generation_config.update(sequence_bias=seq_bias)

          trimmed_ids = batch_ids[:-1]
          # a single-token row leaves nothing to condition on once its tail is removed
          if trimmed_ids.numel() == 0:
              continue

          # if the prompt is a single (non-pad) token, regenerate from bos
          if len(batch_ids[batch_ids != pad_token_id]) == 1:
              trimmed_ids[-1] = bos_token_id

          input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config)

      return input_ids
  def _sample(
      self,
      input_ids: torch.LongTensor,
      logits_processor: LogitsProcessorList,
      stopping_criteria: StoppingCriteriaList,
      generation_config: GenerationConfig,
      synced_gpus: bool = False,
      streamer: Optional["BaseStreamer"] = None,
      **model_kwargs,
  ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
      r"""
      Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
      can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

      Despite the name, this method also implements greedy decoding: when `generation_config.do_sample` is `False`,
      the next token is the `argmax` of the processed scores instead of a multinomial draw.

      Parameters:
          input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
              The sequence used as a prompt for the generation.
          logits_processor (`LogitsProcessorList`):
              An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
              used to modify the prediction scores of the language modeling head applied at each generation step.
          stopping_criteria (`StoppingCriteriaList`):
              An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
              used to tell if the generation loop should stop.
          generation_config ([`~generation.GenerationConfig`]):
              The generation configuration to be used as parametrization of the decoding method.
          synced_gpus (`bool`):
              Whether to continue running the while loop until max_length (needed to avoid deadlocking with
              `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
          streamer (`BaseStreamer`, *optional*):
              Streamer object that will be used to stream the generated sequences. Generated tokens are passed
              through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
          model_kwargs:
              Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
              an encoder-decoder model the kwargs should include `encoder_outputs`.

      Return:
          [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
          A `torch.LongTensor` containing the generated tokens (default behaviour) or, when
          `return_dict_in_generate=True`, a [`~generation.GenerateDecoderOnlyOutput`] if
          `model.config.is_encoder_decoder=False` or a [`~generation.GenerateEncoderDecoderOutput`] if
          `model.config.is_encoder_decoder=True`.
      """
      # init values
      pad_token_id = generation_config._pad_token_tensor
      output_attentions = generation_config.output_attentions
      output_hidden_states = generation_config.output_hidden_states
      output_scores = generation_config.output_scores
      output_logits = generation_config.output_logits
      return_dict_in_generate = generation_config.return_dict_in_generate
      # when some stopping criterion tracks an EOS token, rows that already finished are filled with padding below
      has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
      do_sample = generation_config.do_sample

      # init attention / hidden states / scores tuples
      scores = () if (return_dict_in_generate and output_scores) else None
      raw_logits = () if (return_dict_in_generate and output_logits) else None
      decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
      cross_attentions = () if (return_dict_in_generate and output_attentions) else None
      decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

      # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
      # (these two names are only bound, and only read, when return_dict_in_generate and is_encoder_decoder)
      if return_dict_in_generate and self.config.is_encoder_decoder:
          encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
          encoder_hidden_states = (
              model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
          )

      # keep track of which sequences are already finished
      batch_size, cur_len = input_ids.shape[:2]
      this_peer_finished = False
      # one entry per row: 1 while the row is still generating, 0 once it has finished
      unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
      model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

      # `model_forward` is the eager `__call__` unless auto-compilation is allowed for this call
      model_forward = self.__call__
      compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
      if compile_forward:
          # NOTE(review): disables HF tokenizers parallelism while compiling; presumably to avoid thread/fork
          # interactions with torch.compile -- confirm
          os.environ["TOKENIZERS_PARALLELISM"] = "0"
          # If we use FA2 and a static cache, we cannot compile with fullgraph
          if self.config._attn_implementation == "flash_attention_2":
              # only raise warning if the user passed an explicit compile-config
              if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
                  logger.warning_once(
                      "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
                      "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
                  )
                  generation_config.compile_config.fullgraph = False
          model_forward = self.get_compiled_call(generation_config.compile_config)

      # With prefill chunking, `_prefill_chunking` runs ahead of the loop (presumably processing the prompt in chunks
      # of `prefill_chunk_size` and filling the cache -- confirm), so every loop iteration uses `model_forward`.
      # Otherwise the first iteration (prefill) runs eagerly through `self(...)`.
      if generation_config.prefill_chunk_size is not None:
          model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
          is_prefill = False
      else:
          is_prefill = True

      while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
          # prepare model inputs
          model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

          if is_prefill:
              outputs = self(**model_inputs, return_dict=True)
              is_prefill = False
          else:
              outputs = model_forward(**model_inputs, return_dict=True)

          # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
          model_kwargs = self._update_model_kwargs_for_generation(
              outputs,
              model_kwargs,
              is_encoder_decoder=self.config.is_encoder_decoder,
          )
          if synced_gpus and this_peer_finished:
              continue

          # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
          # (the clone itself is always small)
          # Only the last position is kept, upcast to float32 and moved to the device of `input_ids`.
          next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)

          # pre-process distribution
          next_token_scores = logits_processor(input_ids, next_token_logits)

          # Store scores, attentions and hidden_states when required
          if return_dict_in_generate:
              if output_scores:
                  scores += (next_token_scores,)
              if output_logits:
                  raw_logits += (next_token_logits,)
              if output_attentions:
                  decoder_attentions += (
                      (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                  )
                  if self.config.is_encoder_decoder:
                      cross_attentions += (outputs.cross_attentions,)

              if output_hidden_states:
                  decoder_hidden_states += (
                      (outputs.decoder_hidden_states,)
                      if self.config.is_encoder_decoder
                      else (outputs.hidden_states,)
                  )

          # token selection
          if do_sample:
              probs = nn.functional.softmax(next_token_scores, dim=-1)
              # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
              next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
          else:
              next_tokens = torch.argmax(next_token_scores, dim=-1)

          # finished sentences should have their next token be a padding token
          # (arithmetic select: keeps the sampled token where unfinished == 1, uses pad where unfinished == 0)
          if has_eos_stopping_criteria:
              next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

          # update generated ids, model inputs, and length for next step
          input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
          if streamer is not None:
              streamer.put(next_tokens.cpu())

          # a row stays unfinished only while no stopping criterion fires for it
          # (assumes the criteria return one boolean per row -- confirm in StoppingCriteriaList)
          unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
          this_peer_finished = unfinished_sequences.max() == 0
          cur_len += 1

          # This is needed to properly delete outputs.logits which may be very large for first iteration
          # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
          del outputs

      if streamer is not None:
          streamer.end()

      if return_dict_in_generate:
          if self.config.is_encoder_decoder:
              return GenerateEncoderDecoderOutput(
                  sequences=input_ids,
                  scores=scores,
                  logits=raw_logits,
                  encoder_attentions=encoder_attentions,
                  encoder_hidden_states=encoder_hidden_states,
                  decoder_attentions=decoder_attentions,
                  cross_attentions=cross_attentions,
                  decoder_hidden_states=decoder_hidden_states,
                  past_key_values=model_kwargs.get("past_key_values"),
              )
          else:
              return GenerateDecoderOnlyOutput(
                  sequences=input_ids,
                  scores=scores,
                  logits=raw_logits,
                  attentions=decoder_attentions,
                  hidden_states=decoder_hidden_states,
                  past_key_values=model_kwargs.get("past_key_values"),
              )
      else:
          return input_ids
  2595. @staticmethod
  2596. def _flatten_beam_dim(tensor: torch.Tensor) -> torch.Tensor:
  2597. """[batch_size, num_beams, ...] -> [batch_size * num_beams, ...]"""
  2598. shape = list(tensor.shape)
  2599. return torch.reshape(tensor, [shape[0] * shape[1]] + shape[2:])
  2600. @staticmethod
  2601. def _unflatten_beam_dim(tensor: torch.Tensor, batch_size: int, num_beams: int) -> torch.Tensor:
  2602. """[batch_size * num_beams, ...] -> [batch_size, num_beams, ...]"""
  2603. shape = list(tensor.shape)
  2604. return torch.reshape(tensor, [batch_size, num_beams] + shape[1:])
  2605. @staticmethod
  2606. def _gather_beams(tensor: torch.Tensor, beam_indices: torch.Tensor) -> torch.Tensor:
  2607. """
  2608. Gathers the beam slices indexed by beam_indices into new beam array.
  2609. Args:
  2610. tensor (`torch.Tensor`): A tensor containing data to be gathered. The tensor is a 2D or a 3D tensor
  2611. with the two first dimensions depicting the batch and the beam dimensions.
  2612. beam_indices (`torch.Tensor` of shape `(batch_size, num_beams_to_select)`): The indices of the beams to
  2613. select .
  2614. Returns:
  2615. A tensor with the selected beams
  2616. """
  2617. # `take_along_dim` requires its indices arg to have the same number of dims as `input`
  2618. while len(beam_indices.shape) < len(tensor.shape):
  2619. beam_indices = beam_indices.unsqueeze(-1)
  2620. gathered_tensor = torch.take_along_dim(input=tensor, indices=beam_indices, dim=1)
  2621. return gathered_tensor
  2622. @staticmethod
  2623. def _check_early_stop_heuristic(
  2624. is_early_stop_heuristic_unsatisfied: torch.Tensor,
  2625. running_beam_scores: torch.Tensor,
  2626. beam_scores: torch.Tensor,
  2627. is_sent_finished: torch.Tensor,
  2628. cur_len: int,
  2629. max_length: int,
  2630. decoder_prompt_len: int,
  2631. early_stopping: Union[bool, str],
  2632. length_penalty: float,
  2633. ):
  2634. """
  2635. Determine whether early stopping is possible by checking if the best possible score of running beams
  2636. could still improve upon the finished ones.
  2637. Mechanism:
  2638. - Without a length penalty, beam scores typically decrease as more tokens are generated.
  2639. So, if the *best possible* score from any running beam is already worse than the *worst* finished beam,
  2640. we can safely stop early.
  2641. - With a length penalty, scores may increase with longer sequences. In this case, we use heuristics
  2642. to estimate the best possible score — though this estimate may not always be correct — and stop
  2643. if no further improvement seems likely.
  2644. We apply different heuristics depending on the value of `early_stopping`:
  2645. 1. `early_stopping == False`:
  2646. -> Use a heuristic that assumes the best score comes from the current length minus the decoder prompt length.
  2647. -> See detailed discussion: https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
  2648. 2. `early_stopping == "never"`:
  2649. -> Estimate the best score using either `max_length` or `cur_len`, depending on the sign of `length_penalty`.
  2650. -> A positive length penalty favors longer sequences, so we use `max_length` in that case.
  2651. NOTE: the canonical beam search implementation can be replicated with `early_stopping="never"` and
  2652. `length_penalty=0.0`, which are NOT the default flags. The default behavior was empirically found to produce
  2653. better sequences (prior to 2022), and changing it is BC breaking.
  2654. """
  2655. if early_stopping == "never" and length_penalty > 0.0:
  2656. best_hypothetical_length = max_length - decoder_prompt_len
  2657. else:
  2658. best_hypothetical_length = cur_len - decoder_prompt_len
  2659. best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
  2660. worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
  2661. return is_early_stop_heuristic_unsatisfied & torch.any(
  2662. best_possible_running_score > worst_finished_score, dim=-1, keepdim=True
  2663. )
  2664. @staticmethod
  2665. def _beam_search_has_unfinished_sequences(
  2666. is_early_stop_heuristic_unsatisfied: torch.Tensor,
  2667. is_sent_finished: torch.Tensor,
  2668. next_token_hits_stopping_criteria: torch.Tensor,
  2669. early_stopping: Union[bool, str],
  2670. ):
  2671. """
  2672. Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
  2673. """
  2674. # a. Can the open beams improve the top completed scores?
  2675. improvement_possible = torch.any(is_early_stop_heuristic_unsatisfied)
  2676. # b. Is there still a beam without fully completed sequences? This is only relevant if early_stopping is
  2677. # enabled, where we want to finish as soon as all beams have a completed sequence.
  2678. exists_open_beam = ~(torch.all(is_sent_finished) & (early_stopping is True))
  2679. # c. Have we hit a stopping criteria with all running sequences and have no way to continue? e.g. we have
  2680. # reached `max_length``
  2681. valid_continuations = ~torch.all(next_token_hits_stopping_criteria)
  2682. return improvement_possible & exists_open_beam & valid_continuations
  2683. def _get_top_k_continuations(
  2684. self,
  2685. accumulated_log_probs: torch.Tensor,
  2686. running_sequences: torch.Tensor,
  2687. running_beam_indices: torch.Tensor,
  2688. cur_len: int,
  2689. decoder_prompt_len: int,
  2690. do_sample: bool,
  2691. beams_to_keep: int,
  2692. num_beams: int,
  2693. vocab_size: int,
  2694. batch_size: int,
  2695. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  2696. """
  2697. Get top-K continuations given the accumulated log probs on the next token.
  2698. A few notes to understand what's going on:
  2699. 1. Each item in batch has `num_beams` * `vocab_size` candidate continuations. For each item, get the
  2700. top K [K = (number of EOS tokens + 1) * `num_beams`] candidates with the highest accumulated
  2701. log-probabilities, or sample them without replacement using the accumulated scores
  2702. 2. We gather the top K (as opposed to `num_beams`, or any number lower than K) here so that we have at
  2703. least `num_beams` sequences remaining to continue the live beam search.
  2704. 3. Note that other stopping criteria might result in impossible to continue beams, i.e. all continuations
  2705. selected in this step hit the stopping criteria.
  2706. """
  2707. # TODO (joao): This function should take an optional beam scorer function, to manipulate the scores after
  2708. # token selection. The function should be an argument exposed, so that custom scoring functions can be
  2709. # defined.
  2710. # Gather the top K scores from _all_ beams.
  2711. if do_sample:
  2712. topk_indices = torch.multinomial(
  2713. nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep
  2714. )
  2715. topk_log_probs = torch.gather(input=accumulated_log_probs, dim=1, index=topk_indices)
  2716. else:
  2717. topk_log_probs, topk_indices = torch.topk(accumulated_log_probs, k=beams_to_keep)
  2718. # Gather K top beams, recover the beam index by floor division and token id by modulo division
  2719. topk_current_beam_indices = topk_indices // vocab_size
  2720. topk_running_beam_indices = self._gather_beams(running_beam_indices, topk_current_beam_indices)
  2721. topk_running_sequences = self._gather_beams(running_sequences, topk_current_beam_indices)
  2722. topk_ids = topk_indices % vocab_size
  2723. # Update sequences for the K top-k new sequences.
  2724. topk_running_sequences[:, :, cur_len] = topk_ids
  2725. # we want to store the beam indices with batch information -> real beam index = beam index % num beams
  2726. batch_offset = torch.arange(batch_size, device=topk_ids.device).view(-1, 1) * num_beams
  2727. batch_modified_indices = topk_current_beam_indices + batch_offset
  2728. topk_running_beam_indices[:, :, cur_len - decoder_prompt_len] = batch_modified_indices
  2729. return topk_log_probs, topk_running_sequences, topk_running_beam_indices
  2730. def _get_running_beams_for_next_iteration(
  2731. self,
  2732. topk_log_probs: torch.Tensor,
  2733. topk_running_sequences: torch.Tensor,
  2734. topk_running_beam_indices: torch.Tensor,
  2735. next_token_hits_stopping_criteria: torch.Tensor,
  2736. num_beams: int,
  2737. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  2738. """
  2739. Given the top-K continuations, their scores, and whether they hit a stopping criteria, select the
  2740. best non-finished beams to continue beam search in the next iteration.
  2741. """
  2742. # To prevent these just finished sequences from being used in subsequent iterations, set their log probs
  2743. # to a very large negative value
  2744. topk_running_log_probs = topk_log_probs + next_token_hits_stopping_criteria.to(torch.float32) * -1.0e9
  2745. next_topk_indices = torch.topk(topk_running_log_probs, k=num_beams)[1]
  2746. running_sequences = self._gather_beams(topk_running_sequences, next_topk_indices)
  2747. running_beam_scores = self._gather_beams(topk_running_log_probs, next_topk_indices)
  2748. running_beam_indices = self._gather_beams(topk_running_beam_indices, next_topk_indices)
  2749. return running_sequences, running_beam_scores, running_beam_indices
  2750. def _update_finished_beams(
  2751. self,
  2752. sequences: torch.Tensor,
  2753. topk_running_sequences: torch.Tensor,
  2754. beam_scores: torch.Tensor,
  2755. topk_log_probs: torch.Tensor,
  2756. beam_indices: torch.Tensor,
  2757. topk_running_beam_indices: torch.Tensor,
  2758. is_early_stop_heuristic_unsatisfied: torch.Tensor,
  2759. is_sent_finished: torch.Tensor,
  2760. next_token_hits_stopping_criteria: torch.Tensor,
  2761. top_num_beam_mask: torch.Tensor,
  2762. num_beams: int,
  2763. cur_len: int,
  2764. decoder_prompt_len: int,
  2765. length_penalty: float,
  2766. early_stopping: Union[bool, str],
  2767. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  2768. """
  2769. Updates the finished beams if (and only if) there are new completed sequences that have a higher score than
  2770. the current finished sequences.
  2771. """
  2772. # Only the top `num_beam` sequences can be considered for the final returned sequences. Remember: the
  2773. # remaining sequences only exist as a backup to ensure that we have at least `num_beams` sequences to
  2774. # continue.
  2775. did_top_num_beams_just_finished = next_token_hits_stopping_criteria & top_num_beam_mask[None, :]
  2776. # Further process topk logits for the finished beams
  2777. # - add length penalty
  2778. topk_log_probs = topk_log_probs / ((cur_len + 1 - decoder_prompt_len) ** length_penalty)
  2779. # - make sure no scores can be added anymore if beam is full and early stopping is on
  2780. beams_in_batch_are_full = torch.all(is_sent_finished, axis=-1, keepdims=True) & (early_stopping is True)
  2781. topk_log_probs += beams_in_batch_are_full.to(torch.float32) * -1.0e9
  2782. # - make sure no scores can be added anymore if improvement is not possible
  2783. topk_log_probs += (~is_early_stop_heuristic_unsatisfied).to(torch.float32) * -1.0e9
  2784. # - make sure still running sequences cannot be chosen as finalized beam
  2785. topk_log_probs += (~did_top_num_beams_just_finished) * -1.0e9
  2786. # Get finalized `num_beam` sequences for the next generation step -- combine the previous finalized
  2787. # data with the new finalized sequences (if any, non-finalized sequences have a very large negative score
  2788. # in this step), and keep the best `num_beams` sequences.
  2789. merged_sequences = torch.cat((sequences, topk_running_sequences), dim=1)
  2790. merged_scores = torch.cat((beam_scores, topk_log_probs), dim=1)
  2791. merged_beam_indices = torch.cat((beam_indices, topk_running_beam_indices), dim=1)
  2792. merged_is_sent_finished = torch.cat((is_sent_finished, did_top_num_beams_just_finished), dim=1)
  2793. topk_merged_indices = torch.topk(merged_scores, k=num_beams)[1]
  2794. sequences = self._gather_beams(merged_sequences, topk_merged_indices)
  2795. beam_scores = self._gather_beams(merged_scores, topk_merged_indices)
  2796. beam_indices = self._gather_beams(merged_beam_indices, topk_merged_indices)
  2797. is_sent_finished = self._gather_beams(merged_is_sent_finished, topk_merged_indices)
  2798. return sequences, beam_scores, beam_indices, is_sent_finished
  2799. # end of auxiliary functions for beam search
  2800. def _beam_search(
  2801. self,
  2802. input_ids: torch.LongTensor,
  2803. logits_processor: LogitsProcessorList,
  2804. stopping_criteria: StoppingCriteriaList,
  2805. generation_config: GenerationConfig,
  2806. synced_gpus: bool = False,
  2807. **model_kwargs,
  2808. ) -> Union[GenerateBeamOutput, torch.LongTensor]:
  2809. r"""
  2810. Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
  2811. can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
  2812. If it's the first time you're diving into Beam Search, we recommend you read the following blog post:
  2813. https://huggingface.co/blog/how-to-generate (especially the beam search section).
  2814. You can recompute the sequence scores from the individual scores using the `compute_transition_scores` function
  2815. (https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores)
  2816. Parameters:
  2817. input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
  2818. The sequence used as a prompt for the generation.
  2819. logits_processor (`LogitsProcessorList`):
  2820. An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
  2821. used to modify the prediction scores of the language modeling head applied at each generation step.
  2822. stopping_criteria (`StoppingCriteriaList`:
  2823. An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
  2824. used to tell if the generation loop should stop.
  2825. generation_config ([`~generation.GenerationConfig`]):
  2826. The generation configuration to be used as parametrization of the decoding method.
  2827. synced_gpus (`bool`):
  2828. Whether to continue running the while loop until max_length (needed to avoid deadlocking with
  2829. `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
  2830. model_kwargs:
  2831. Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
  2832. an encoder-decoder model the kwargs should include `encoder_outputs`.
  2833. Return:
  2834. [`generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
  2835. `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
  2836. [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
  2837. `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
  2838. `model.config.is_encoder_decoder=True`.
  2839. """
  2840. # 1. init beam_search values
  2841. pad_token_id = generation_config._pad_token_tensor
  2842. eos_token_id = generation_config._eos_token_tensor
  2843. output_attentions = generation_config.output_attentions
  2844. output_hidden_states = generation_config.output_hidden_states
  2845. output_scores = generation_config.output_scores
  2846. output_logits = generation_config.output_logits
  2847. return_dict_in_generate = generation_config.return_dict_in_generate
  2848. do_sample = generation_config.do_sample
  2849. early_stopping = generation_config.early_stopping
  2850. length_penalty = generation_config.length_penalty
  2851. max_length = generation_config.max_length
  2852. num_beams = generation_config.num_beams
  2853. num_return_sequences = generation_config.num_return_sequences
  2854. batch_size_unflattened, cur_len = input_ids.shape[:2]
  2855. batch_size = batch_size_unflattened // num_beams
  2856. # TODO (joao): standardize special cases
  2857. if self.__class__.__name__ == "MoshiDepthDecoder":
  2858. vocab_size = self.config.audio_vocab_size
  2859. elif self.__class__.__name__ == "ImageGPTForCausalImageModeling":
  2860. vocab_size = self.get_output_embeddings().out_features
  2861. elif self.__class__.__name__ == "BarkSemanticModel":
  2862. vocab_size = self.config.output_vocab_size
  2863. else:
  2864. vocab_size = self.config.get_text_config().vocab_size
  2865. decoder_prompt_len = cur_len
  2866. this_peer_finished = False
  2867. # At each beam search step, we want to keep top K [K = (number of EOS tokens + 1) * `num_beams`] candidates
  2868. # with the highest log-probabilities, or sample K continuations without replacement. We gather the top K
  2869. # (as opposed to `num_beams`, or any number lower than K) so that we have at least `num_beams` sequences
  2870. # non-finished to continue the live beam search, in case the top `num_beams` all select an EOS token.
  2871. n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
  2872. beams_to_keep = max(2, 1 + n_eos_tokens) * num_beams
  2873. top_num_beam_mask = torch.cat(
  2874. (torch.ones((num_beams), dtype=torch.bool), torch.zeros((beams_to_keep - num_beams), dtype=torch.bool)),
  2875. dim=0,
  2876. ).to(input_ids.device)
  2877. model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
  2878. # (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there
  2879. # are newer low-memory alternatives like the offloaded cache)
  2880. sequential = generation_config.low_memory
  2881. if sequential:
  2882. raise ValueError(
  2883. "`low_memory=True` is not supported after the beam search refactor. Please check the discussion in "
  2884. "#35802 *after the PR got merged*, and add a comment there if your questions are not yet answered."
  2885. )
  2886. # 2. init output tuples
  2887. all_scores = () if (return_dict_in_generate and output_scores) else None
  2888. raw_logits = () if (return_dict_in_generate and output_logits) else None
  2889. beam_indices = () if (return_dict_in_generate and output_logits) else None
  2890. decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
  2891. cross_attentions = () if (return_dict_in_generate and output_attentions) else None
  2892. decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
  2893. # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
  2894. if return_dict_in_generate and self.config.is_encoder_decoder:
  2895. encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
  2896. encoder_hidden_states = (
  2897. model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
  2898. )
  2899. # 3. init running tensors and static-shaped placeholders
  2900. # per batch, beam-item holding current token in loop and completed sequences
  2901. output_fill_value = pad_token_id or eos_token_id[0] if eos_token_id is not None else -1
  2902. running_sequences = torch.full(
  2903. (batch_size, num_beams, max_length),
  2904. fill_value=output_fill_value,
  2905. dtype=torch.int64,
  2906. device=input_ids.device,
  2907. )
  2908. running_sequences[:, :, :cur_len] = self._unflatten_beam_dim(input_ids, batch_size, num_beams)
  2909. sequences = running_sequences.detach().clone()
  2910. # per batch, beam-item score, logprobs
  2911. # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
  2912. # of the first beam are considered to avoid sampling the exact same tokens across all beams.
  2913. running_beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
  2914. running_beam_scores[:, 1:] = -1e9
  2915. beam_scores = torch.full((batch_size, num_beams), fill_value=-1e9, dtype=torch.float, device=input_ids.device)
  2916. # per batch, beam-item state bit indicating if sentence has finished.
  2917. is_sent_finished = torch.zeros((batch_size, num_beams), dtype=torch.bool, device=input_ids.device)
  2918. # per batch state bit indicating if there is a possibility to improve the best finished sentence.
  2919. is_early_stop_heuristic_unsatisfied = torch.ones((batch_size, 1), dtype=torch.bool, device=input_ids.device)
  2920. # per batch, beam-item state bit indicating if there are valid continuations.
  2921. next_token_hits_stopping_criteria = torch.zeros(
  2922. (batch_size, num_beams), dtype=torch.bool, device=input_ids.device
  2923. )
  2924. # per batch selected beam indices
  2925. running_beam_indices = torch.full(
  2926. (batch_size, num_beams, max_length - cur_len), fill_value=-1, dtype=torch.int32, device=input_ids.device
  2927. )
  2928. beam_indices = running_beam_indices.detach().clone()
  2929. # 4. run the generation loop
  2930. while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
  2931. # a. Forward current tokens, obtain the logits
  2932. flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
  2933. model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)
  2934. model_outputs = self(**model_inputs, return_dict=True)
  2935. # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
  2936. model_kwargs = self._update_model_kwargs_for_generation(
  2937. model_outputs,
  2938. model_kwargs,
  2939. is_encoder_decoder=self.config.is_encoder_decoder,
  2940. )
  2941. if synced_gpus and this_peer_finished:
  2942. continue
  2943. # Copy is needed to avoid keeping a hanging ref
  2944. logits = model_outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
  2945. # b. Compute log probs -- get log probabilities from logits, process logits with processors (*e.g.*
  2946. # `temperature`, ...), and add new logprobs to existing running logprobs scores.
  2947. log_probs = nn.functional.log_softmax(logits, dim=-1)
  2948. log_probs = logits_processor(flat_running_sequences, log_probs)
  2949. # Store logits, attentions and hidden_states when required
  2950. if return_dict_in_generate:
  2951. if output_logits:
  2952. raw_logits += (logits.clone(),)
  2953. if return_dict_in_generate and output_scores:
  2954. all_scores += (log_probs.clone(),)
  2955. if output_attentions:
  2956. decoder_attentions += (
  2957. (model_outputs.decoder_attentions,)
  2958. if self.config.is_encoder_decoder
  2959. else (model_outputs.attentions,)
  2960. )
  2961. if self.config.is_encoder_decoder:
  2962. cross_attentions += (model_outputs.cross_attentions,)
  2963. if output_hidden_states:
  2964. decoder_hidden_states += (
  2965. (model_outputs.decoder_hidden_states,)
  2966. if self.config.is_encoder_decoder
  2967. else (model_outputs.hidden_states,)
  2968. )
  2969. # This is needed to properly delete logits which may be very large for first iteration
  2970. # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
  2971. del model_outputs
  2972. log_probs = self._unflatten_beam_dim(log_probs, batch_size, num_beams)
  2973. log_probs = log_probs + running_beam_scores[:, :, None]
  2974. log_probs = torch.reshape(log_probs, (batch_size, num_beams * vocab_size))
  2975. # c. Retrieve top-K continuations, i.e. select the next token (greedy or sampling) and then keep the best
  2976. # continuations among all beams based on the accumulated scores.
  2977. topk_log_probs, topk_running_sequences, topk_running_beam_indices = self._get_top_k_continuations(
  2978. accumulated_log_probs=log_probs,
  2979. running_sequences=running_sequences,
  2980. running_beam_indices=running_beam_indices,
  2981. cur_len=cur_len,
  2982. decoder_prompt_len=decoder_prompt_len,
  2983. do_sample=do_sample,
  2984. beams_to_keep=beams_to_keep,
  2985. num_beams=num_beams,
  2986. vocab_size=vocab_size,
  2987. batch_size=batch_size,
  2988. )
  2989. # d. Check which running sequences have finished
  2990. next_token_hits_stopping_criteria = stopping_criteria(
  2991. self._flatten_beam_dim(topk_running_sequences[:, :, : cur_len + 1]), # remove unfilled token indexes
  2992. all_scores,
  2993. )
  2994. next_token_hits_stopping_criteria = self._unflatten_beam_dim(
  2995. next_token_hits_stopping_criteria, batch_size, beams_to_keep
  2996. )
  2997. # e. Get the non-finished running `num_beams` sequences for the next generation step
  2998. running_sequences, running_beam_scores, running_beam_indices = self._get_running_beams_for_next_iteration(
  2999. topk_log_probs=topk_log_probs,
  3000. topk_running_sequences=topk_running_sequences,
  3001. topk_running_beam_indices=topk_running_beam_indices,
  3002. next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
  3003. num_beams=num_beams,
  3004. )
  3005. # f. Update the completed beams if a new high score in a finished sequence is found
  3006. sequences, beam_scores, beam_indices, is_sent_finished = self._update_finished_beams(
  3007. sequences=sequences,
  3008. topk_running_sequences=topk_running_sequences,
  3009. beam_scores=beam_scores,
  3010. topk_log_probs=topk_log_probs,
  3011. beam_indices=beam_indices,
  3012. topk_running_beam_indices=topk_running_beam_indices,
  3013. is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
  3014. is_sent_finished=is_sent_finished,
  3015. next_token_hits_stopping_criteria=next_token_hits_stopping_criteria,
  3016. top_num_beam_mask=top_num_beam_mask,
  3017. num_beams=num_beams,
  3018. cur_len=cur_len,
  3019. decoder_prompt_len=decoder_prompt_len,
  3020. length_penalty=length_penalty,
  3021. early_stopping=early_stopping,
  3022. )
  3023. # g. Prepare remaining data for the next iteration, including computing the stopping condition for
  3024. # beam search as a whole (as opposed to individual beams, i.e. `stopping_criteria`)
  3025. # pluck the cache from the beam indices that will be used in the next iteration
  3026. # NOTE: we need to check if `self._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
  3027. if model_kwargs.get("past_key_values", None) is not None:
  3028. beam_idx = self._flatten_beam_dim(running_beam_indices[..., cur_len - decoder_prompt_len])
  3029. if hasattr(self, "_reorder_cache"):
  3030. model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
  3031. else:
  3032. model_kwargs["past_key_values"].reorder_cache(beam_idx)
  3033. cur_len = cur_len + 1
  3034. is_early_stop_heuristic_unsatisfied = self._check_early_stop_heuristic(
  3035. is_early_stop_heuristic_unsatisfied=is_early_stop_heuristic_unsatisfied,
  3036. running_beam_scores=running_beam_scores,
  3037. beam_scores=beam_scores,
  3038. is_sent_finished=is_sent_finished,
  3039. cur_len=cur_len,
  3040. max_length=max_length,
  3041. decoder_prompt_len=decoder_prompt_len,
  3042. early_stopping=early_stopping,
  3043. length_penalty=length_penalty,
  3044. )
  3045. this_peer_finished = not self._beam_search_has_unfinished_sequences(
  3046. is_early_stop_heuristic_unsatisfied,
  3047. is_sent_finished,
  3048. next_token_hits_stopping_criteria,
  3049. early_stopping,
  3050. )
  3051. # 5. prepare outputs
  3052. # Take best beams for each batch (the score is sorted in descending order)
  3053. sequences = self._flatten_beam_dim(sequences[:, :num_return_sequences, :])
  3054. beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
  3055. beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])
  3056. # Crop the static-shaped tensors to the actual size.
  3057. # `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each
  3058. # step. We can use it to detect the generated length, which may be != `cur_len` (e.g. selected beam is from a
  3059. # previous decoding iteration)
  3060. max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
  3061. output_length = decoder_prompt_len + max_generated_length
  3062. sequences = sequences[:, :output_length]
  3063. beam_indices = beam_indices[:, :max_generated_length]
  3064. if return_dict_in_generate:
  3065. if not output_scores:
  3066. beam_scores = None
  3067. if self.config.is_encoder_decoder:
  3068. return GenerateBeamEncoderDecoderOutput(
  3069. sequences=sequences,
  3070. sequences_scores=beam_scores,
  3071. scores=all_scores,
  3072. logits=raw_logits,
  3073. beam_indices=beam_indices,
  3074. encoder_attentions=encoder_attentions,
  3075. encoder_hidden_states=encoder_hidden_states,
  3076. decoder_attentions=decoder_attentions,
  3077. cross_attentions=cross_attentions,
  3078. decoder_hidden_states=decoder_hidden_states,
  3079. past_key_values=model_kwargs.get("past_key_values"),
  3080. )
  3081. else:
  3082. return GenerateBeamDecoderOnlyOutput(
  3083. sequences=sequences,
  3084. sequences_scores=beam_scores,
  3085. scores=all_scores,
  3086. logits=raw_logits,
  3087. beam_indices=beam_indices,
  3088. attentions=decoder_attentions,
  3089. hidden_states=decoder_hidden_states,
  3090. past_key_values=model_kwargs.get("past_key_values"),
  3091. )
  3092. else:
  3093. return sequences
def _assisted_decoding(
    self,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    inputs_tensor: Optional[torch.FloatTensor] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
    **sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a
    candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text
    models.

    Each loop iteration:
    1. asks a candidate generator for a draft continuation;
    2. scores the whole draft with a single forward pass of this model;
    3. keeps the longest accepted prefix of the draft, normally plus one token chosen from this model's logits;
    4. crops the cache so that rejected draft tokens leave no cached entries.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        inputs_tensor (`torch.FloatTensor`, *optional*):
            The input tensor for generation. For decoder models, usually `input_ids`. For encoder-decoder models,
            the tensor that produced `model_kwargs["encoder_outputs"]`.
        assistant_model (`PreTrainedModel`, *optional*):
            The model used to assist the generation process. If not provided, the main model will be used.
        assistant_tokenizer (`PreTrainedTokenizerBase`, *optional*):
            The tokenizer used for the assistant model. If not provided, the token space is assumed to be the same.
        tokenizer (`PreTrainedTokenizerBase`, *optional*):
            The tokenizer used for the main model. If not provided, the token space is assumed to be the same.
        model_kwargs:
            Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
            If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
        `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.

    Raises:
        ValueError: if `model_kwargs["use_cache"]` is falsy; if a static-style cache is requested
            (`cache_implementation` in `["static", "hybrid", "sliding_window"]`, or a `past_key_values` object with
            a layer whose `is_compileable` is true); or if the batch size is larger than 1.
    """
    # The cache must be dynamic for assisted generation, and the check must happen AFTER preparing cache
    if not model_kwargs["use_cache"]:
        raise ValueError("assisted generate requires `use_cache=True`")
    if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] or (
        "past_key_values" in model_kwargs
        and hasattr(model_kwargs["past_key_values"], "layers")
        and any(getattr(l, "is_compileable", False) for l in model_kwargs["past_key_values"].layers)
    ):
        raise ValueError("assisted generate is not supported with Static cache classes`")
    # Get the candidate generator, given the parameterization
    candidate_generator = self._get_candidate_generator(
        generation_config=generation_config,
        input_ids=input_ids,
        inputs_tensor=inputs_tensor,
        assistant_model=assistant_model,
        logits_processor=logits_processor,
        target_tokenizer=tokenizer,
        assistant_tokenizer=assistant_tokenizer,
        model_kwargs=model_kwargs,
    )
    # init values
    do_sample = generation_config.do_sample
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate

    # init attention / hidden states / scores tuples
    # (`None` means "not requested"; `()` is an accumulator that gets one entry per generated token)
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    batch_size, cur_len = input_ids.shape[:2]
    if batch_size > 1:
        raise ValueError("assisted generate is only supported for batch_size = 1")
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

    this_peer_finished = False
    is_first_iteration = True  # to preserve the same API in the output as other generation methods
    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        cur_len = input_ids.shape[1]

        # 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
        candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
        candidate_input_ids = candidate_input_ids.to(self.device)
        if candidate_logits is not None:
            candidate_logits = candidate_logits.to(self.device)

        # Number of draft tokens appended after the current sequence.
        candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
        # Whether the *full* draft would already trigger a stopping criterion (max length / EOS); used below to
        # avoid emitting a token past that point.
        is_done_candidate = stopping_criteria(candidate_input_ids, None)

        # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
        # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
        # we use this forward pass to also pick the subsequent logits in the original model.

        # 2.1. Prepare the model inputs
        # Shallow copy, so that replacing entries below does not modify `model_kwargs` itself.
        candidate_kwargs = copy.copy(model_kwargs)
        candidate_kwargs = _prepare_attention_mask(
            candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
        )
        candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
        if "cache_position" in candidate_kwargs:
            # Extend the cache positions to also cover the `candidate_length` draft tokens.
            candidate_kwargs["cache_position"] = torch.cat(
                (
                    candidate_kwargs["cache_position"],
                    torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
                ),
                dim=0,
            )

        model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
        # Only the last `candidate_length + 1` positions are used in step 2.3, so ask the model to compute just
        # those logits when it accepts `logits_to_keep`.
        if "logits_to_keep" in model_inputs:
            model_inputs["logits_to_keep"] = candidate_length + 1

        # 2.2. Run a forward pass on the candidate sequence
        outputs = self(**model_inputs)

        # 2.3. Process the new logits
        # .float() is needed to retain precision for later logits manipulations
        new_logits = outputs.logits[:, -candidate_length - 1 :].to(
            dtype=torch.float32, device=input_ids.device
        )  # excludes the input prompt if present
        # Unprocessed copy, returned as `logits` when `output_logits` is set; `new_logits` is overwritten in place
        # by the logits processors below.
        next_token_logits = new_logits.clone()
        if len(logits_processor) > 0:
            # Position i predicts the token following `candidate_input_ids[:, : cur_len + i]`, so each processor
            # call sees exactly that prefix.
            for i in range(candidate_length + 1):
                new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])

        # 3. Select the accepted tokens. There are two possible cases:
        # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
        # 👉 Apply algorithm 1 from the speculative decoding paper (https://huggingface.co/papers/2211.17192).
        if do_sample and candidate_logits is not None:
            valid_tokens, n_matches = _speculative_sampling(
                candidate_input_ids,
                candidate_logits,
                candidate_length,
                new_logits,
                is_done_candidate,
            )

        # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
        # original model logits with the candidate tokens. We can keep the candidate tokens until the first
        # mismatch, or until the max length is reached.
        else:
            if do_sample:
                probs = new_logits.softmax(dim=-1)
                selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
            else:
                selected_tokens = new_logits.argmax(dim=-1)

            candidate_new_tokens = candidate_input_ids[:, cur_len:]
            # `~(a == b)` marks mismatches; its running sum stays below 1 only up to the first mismatch, so the
            # sum counts the leading run of draft tokens that agree with this model's choices.
            n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

            # Ensure we don't generate beyond max_len or an EOS token
            if is_done_candidate and n_matches == candidate_length:
                n_matches -= 1
            valid_tokens = selected_tokens[:, : n_matches + 1]

        # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
        # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
        # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
        # is no match.

        # 4.1. Get the valid continuation, after the matching tokens
        input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
        if streamer is not None:
            streamer.put(valid_tokens.cpu())
        new_cur_len = input_ids.shape[1]

        # 4.2. Discard past key values relative to unused assistant tokens
        # Keep `new_cur_len - 1` positions: the last accepted token has no valid cache entry yet (it was chosen
        # from this pass's logits, or its slot holds a rejected draft token), so it is presumably fed to the model
        # again on the next iteration.
        outputs.past_key_values.crop(new_cur_len - 1)

        # 5. Update the candidate generation strategy if needed
        candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)

        # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
            num_new_tokens=n_matches + 1,
        )
        if synced_gpus and this_peer_finished:
            continue

        # Store scores, attentions and hidden_states when required
        # Assistant: modified to append one tuple element per token, as in the other generation methods.
        if return_dict_in_generate:
            newly_added_length = n_matches + 1
            if output_scores:
                scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
            if output_logits:
                raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))

            # On the first iteration the forward pass also covered the prompt, whose attentions/hidden states are
            # reported as well, so the whole new sequence length is passed to `_split_model_outputs`.
            newly_added_length = new_cur_len if is_first_iteration else newly_added_length
            if output_attentions:
                if self.config.is_encoder_decoder:
                    cross_attentions = _split_model_outputs(
                        cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
                    )
                    decoder_attentions = _split_model_outputs(
                        decoder_attentions,
                        outputs.decoder_attentions,
                        cur_len,
                        newly_added_length,
                        is_decoder_attention=True,
                    )
                # some (V)LLMs have hard requirement on SDPA and thus never return attn
                elif outputs.attentions[0] is not None:
                    decoder_attentions = _split_model_outputs(
                        decoder_attentions,
                        outputs.attentions,
                        cur_len,
                        newly_added_length,
                        is_decoder_attention=True,
                    )
            if output_hidden_states:
                if self.config.is_encoder_decoder:
                    decoder_hidden_states = _split_model_outputs(
                        decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
                    )
                else:
                    decoder_hidden_states = _split_model_outputs(
                        decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
                    )

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        is_first_iteration = False

    if streamer is not None:
        streamer.end()

    # With the "heuristic" schedule the candidate generator's `num_assistant_tokens` may change during generation;
    # write the final value back to the assistant's generation config (presumably so the next `generate` call
    # starts from the adapted value).
    if (
        hasattr(candidate_generator, "assistant_model")
        and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic"
    ):
        candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
            candidate_generator.num_assistant_tokens
        )
    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids
  3359. def _prefill_chunking(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, **model_kwargs):
  3360. # Even if we are not compiling the forward, flex is always compiled when used. With chunk prefill, we may
  3361. # end up needing just a bit more graphs than the default (which is 8). Doing this avoids very cryptic warnings
  3362. torch._dynamo.config.cache_size_limit = 64
  3363. chunk_size = generation_config.prefill_chunk_size
  3364. # Only chunk up the token just before last, so that decoding is completely performed outside this function
  3365. # (here we simply prefill the cache)
  3366. input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)
  3367. if "past_key_values" not in model_kwargs:
  3368. raise ValueError("Cannot use prefill chunking without a cache")
  3369. model_forward = self.forward
  3370. compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
  3371. if compile_forward:
  3372. model_forward = self.get_compiled_call(generation_config.compile_config)
  3373. attention_mask = model_kwargs.pop("attention_mask", None)
  3374. past_length = 0
  3375. for input_chunk in input_chunks:
  3376. current_length = past_length + input_chunk.shape[-1]
  3377. # Prepare inputs
  3378. if attention_mask is not None:
  3379. model_kwargs["attention_mask"] = attention_mask[:, :current_length]
  3380. model_kwargs["cache_position"] = torch.arange(
  3381. past_length, current_length, dtype=torch.long, device=input_chunk.device
  3382. )
  3383. model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0)
  3384. model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs)
  3385. outputs = model_forward(**model_inputs, return_dict=True)
  3386. model_kwargs["past_key_values"] = outputs.past_key_values
  3387. past_length = current_length
  3388. model_kwargs["attention_mask"] = attention_mask
  3389. model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
  3390. _ = model_kwargs.pop("position_ids", None)
  3391. return model_kwargs
  3392. def _speculative_sampling(
  3393. candidate_input_ids,
  3394. candidate_logits,
  3395. candidate_length,
  3396. new_logits,
  3397. is_done_candidate,
  3398. ):
  3399. """
  3400. Applies sampling as in the speculative decoding paper (https://huggingface.co/papers/2211.17192, algorithm 1). Returns
  3401. the selected tokens, as well as the number of candidate matches.
  3402. NOTE: Unless otherwise stated, the variable names match those in the paper.
  3403. """
  3404. new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
  3405. # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
  3406. # selected by the assistant, respectively.
  3407. q = candidate_logits.softmax(dim=-1)
  3408. q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
  3409. p = new_logits.softmax(dim=-1)
  3410. p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
  3411. probability_ratio = p_i / q_i
  3412. # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
  3413. # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
  3414. # (= keep with p = probability_ratio). Keep all the tokens until the first rejection
  3415. r_i = torch.rand_like(probability_ratio)
  3416. is_accepted = r_i <= probability_ratio
  3417. n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1
  3418. # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
  3419. if is_done_candidate and n_matches == candidate_length:
  3420. # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
  3421. # due to acceptance on EOS we fix `n_matches`
  3422. n_matches -= 1
  3423. valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
  3424. else:
  3425. # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
  3426. gamma = candidate_logits.shape[1]
  3427. p_n_plus_1 = p[:, n_matches, :]
  3428. if n_matches < gamma:
  3429. q_n_plus_1 = q[:, n_matches, :]
  3430. p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
  3431. p_prime.div_(p_prime.sum())
  3432. else:
  3433. p_prime = p_n_plus_1
  3434. t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
  3435. # The selected tokens include the matches (if any) plus the next sampled tokens
  3436. if n_matches > 0:
  3437. valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
  3438. else:
  3439. valid_tokens = t
  3440. return valid_tokens, n_matches
  3441. def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
  3442. """
  3443. Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
  3444. where each member corresponds to a single generated token.
  3445. """
  3446. # Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the
  3447. # prompt.
  3448. if len(outputs) == 0:
  3449. new_tuple = ()
  3450. for layer in new_outputs:
  3451. last_dim_size = cur_len if is_decoder_attention else layer.shape[-1]
  3452. new_tuple += (layer[..., :cur_len, :last_dim_size],)
  3453. outputs += (new_tuple,)
  3454. # The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly
  3455. cur_len += 1
  3456. added_len -= cur_len
  3457. for i in range(added_len):
  3458. new_tuple = ()
  3459. for layer in new_outputs:
  3460. last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1]
  3461. new_tuple += (layer[..., i : i + 1, :last_dim_size],)
  3462. outputs += (new_tuple,)
  3463. return outputs