modeling_perceiver.py 134 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607260826092610261126122613261426152616261726182619262026212622262326242625262626272628262926302631263226332634263526362637263826392640264126422643264426452646264726482649265026512652265326542655265626572658265926602661266226632664266526662667266826692670267126722673267426752676267726782679268026812682268326842685268626872688268926902691269226932694269526962697269826992700270127022703270427052706270727082709271027112712271327142715271627172718271927202721272227232724272527262727272827292730273127322733273427352736273727382739274027412742274327442745274627472748274927502751275227532754275527562757275827592760276127622763276427652766276727682769277027712772277327742775277627772778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348
  1. # coding=utf-8
  2. # Copyright 2021 Deepmind and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Perceiver model."""
  16. import abc
  17. import math
  18. from collections.abc import Mapping
  19. from dataclasses import dataclass
  20. from functools import reduce
  21. from operator import __add__
  22. from typing import Any, Callable, Optional, Union
  23. import numpy as np
  24. import torch
  25. from torch import nn
  26. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  27. from ...activations import ACT2FN
  28. from ...modeling_outputs import BaseModelOutputWithCrossAttentions
  29. from ...modeling_utils import PreTrainedModel
  30. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
  31. from ...utils import ModelOutput, auto_docstring, logging, torch_int
  32. from .configuration_perceiver import PerceiverConfig
  33. ModalitySizeType = Mapping[str, int]
  34. PreprocessorOutputType = tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]
  35. PreprocessorType = Callable[..., PreprocessorOutputType]
  36. PostprocessorType = Callable[..., Any]
  37. logger = logging.get_logger(__name__)
  38. @dataclass
  39. @auto_docstring(
  40. custom_intro="""
  41. Base class for Perceiver base model's outputs, with potential hidden states, attentions and cross-attentions.
  42. """
  43. )
  44. class PerceiverModelOutput(ModelOutput):
  45. r"""
  46. logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
  47. Classification (or regression if config.num_labels==1) scores (before SoftMax).
  48. """
  49. logits: Optional[torch.FloatTensor] = None
  50. last_hidden_state: Optional[torch.FloatTensor] = None
  51. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  52. attentions: Optional[tuple[torch.FloatTensor]] = None
  53. cross_attentions: Optional[tuple[torch.FloatTensor]] = None
  54. @dataclass
  55. @auto_docstring(
  56. custom_intro="""
  57. Base class for Perceiver decoder outputs, with potential cross-attentions.
  58. """
  59. )
  60. class PerceiverDecoderOutput(ModelOutput):
  61. r"""
  62. logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
  63. Output of the basic decoder.
  64. """
  65. logits: Optional[torch.FloatTensor] = None
  66. cross_attentions: Optional[tuple[torch.FloatTensor]] = None
  67. @dataclass
  68. @auto_docstring(
  69. custom_intro="""
  70. Base class for Perceiver's masked language model outputs.
  71. """
  72. )
  73. class PerceiverMaskedLMOutput(ModelOutput):
  74. r"""
  75. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  76. Masked language modeling (MLM) loss.
  77. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  78. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  79. """
  80. loss: Optional[torch.FloatTensor] = None
  81. logits: Optional[torch.FloatTensor] = None
  82. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  83. attentions: Optional[tuple[torch.FloatTensor]] = None
  84. cross_attentions: Optional[tuple[torch.FloatTensor]] = None
  85. @dataclass
  86. @auto_docstring(
  87. custom_intro="""
  88. Base class for Perceiver's outputs of sequence/image classification models, optical flow and multimodal
  89. autoencoding.
  90. """
  91. )
  92. class PerceiverClassifierOutput(ModelOutput):
  93. r"""
  94. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  95. Classification (or regression if config.num_labels==1) loss.
  96. logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
  97. Classification (or regression if config.num_labels==1) scores (before SoftMax).
  98. """
  99. loss: Optional[torch.FloatTensor] = None
  100. logits: Optional[torch.FloatTensor] = None
  101. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  102. attentions: Optional[tuple[torch.FloatTensor]] = None
  103. cross_attentions: Optional[tuple[torch.FloatTensor]] = None
  104. class PerceiverEmbeddings(nn.Module):
  105. """Construct the latent embeddings."""
  106. def __init__(self, config):
  107. super().__init__()
  108. self.latents = nn.Parameter(torch.randn(config.num_latents, config.d_latents))
  109. def forward(self, batch_size: int):
  110. return self.latents.expand(batch_size, -1, -1) # Thanks, Phil Wang
  111. class PerceiverSelfAttention(nn.Module):
  112. """Multi-headed {cross, self}-attention. Can be used both in the encoder as well as in the decoder."""
  113. def __init__(
  114. self,
  115. config,
  116. is_cross_attention=False,
  117. qk_channels=None,
  118. v_channels=None,
  119. num_heads=1,
  120. q_dim=None,
  121. kv_dim=None,
  122. ):
  123. super().__init__()
  124. self.num_heads = num_heads
  125. # Q and K must have the same number of channels.
  126. # Default to preserving Q's input's shape.
  127. if qk_channels is None:
  128. qk_channels = q_dim
  129. # V's num_channels determines the shape of the output of QKV-attention.
  130. # Default to the same number of channels used in the key-query operation.
  131. if v_channels is None:
  132. v_channels = qk_channels
  133. if qk_channels % num_heads != 0:
  134. raise ValueError(f"qk_channels ({qk_channels}) must be divisible by num_heads ({num_heads}).")
  135. if v_channels % num_heads != 0:
  136. raise ValueError(f"v_channels ({v_channels}) must be divisible by num_heads ({num_heads}).")
  137. self.qk_channels = qk_channels
  138. self.v_channels = v_channels
  139. self.qk_channels_per_head = self.qk_channels // num_heads
  140. self.v_channels_per_head = self.v_channels // num_heads
  141. # Layer normalization
  142. self.layernorm1 = nn.LayerNorm(q_dim)
  143. self.layernorm2 = nn.LayerNorm(kv_dim) if is_cross_attention else nn.Identity()
  144. # Projection matrices
  145. self.query = nn.Linear(q_dim, qk_channels)
  146. self.key = nn.Linear(kv_dim, qk_channels)
  147. self.value = nn.Linear(kv_dim, v_channels)
  148. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  149. def transpose_for_scores(self, x, channels_per_head):
  150. new_x_shape = x.size()[:-1] + (self.num_heads, channels_per_head)
  151. x = x.view(*new_x_shape)
  152. return x.permute(0, 2, 1, 3)
  153. def forward(
  154. self,
  155. hidden_states: torch.Tensor,
  156. attention_mask: Optional[torch.FloatTensor] = None,
  157. head_mask: Optional[torch.FloatTensor] = None,
  158. inputs: Optional[torch.FloatTensor] = None,
  159. inputs_mask: Optional[torch.FloatTensor] = None,
  160. output_attentions: Optional[bool] = False,
  161. ) -> tuple[torch.Tensor]:
  162. hidden_states = self.layernorm1(hidden_states)
  163. inputs = self.layernorm2(inputs)
  164. # Project queries, keys and values to a common feature dimension. If this is instantiated as a cross-attention module,
  165. # the keys and values come from the inputs; the attention mask needs to be such that the inputs's non-relevant tokens are not attended to.
  166. is_cross_attention = inputs is not None
  167. queries = self.query(hidden_states)
  168. if is_cross_attention:
  169. keys = self.key(inputs)
  170. values = self.value(inputs)
  171. attention_mask = inputs_mask
  172. else:
  173. keys = self.key(hidden_states)
  174. values = self.value(hidden_states)
  175. # Reshape channels for multi-head attention.
  176. # We reshape from (batch_size, time, channels) to (batch_size, num_heads, time, channels per head)
  177. queries = self.transpose_for_scores(queries, self.qk_channels_per_head)
  178. keys = self.transpose_for_scores(keys, self.qk_channels_per_head)
  179. values = self.transpose_for_scores(values, self.v_channels_per_head)
  180. # Take the dot product between the queries and keys to get the raw attention scores.
  181. attention_scores = torch.matmul(queries, keys.transpose(-1, -2))
  182. batch_size, num_heads, seq_len, q_head_dim = queries.shape
  183. _, _, _, v_head_dim = values.shape
  184. hiddens = self.num_heads * v_head_dim
  185. attention_scores = attention_scores / math.sqrt(q_head_dim)
  186. if attention_mask is not None:
  187. # Apply the attention mask (precomputed for all layers in PerceiverModel forward() function)
  188. attention_scores = attention_scores + attention_mask
  189. # Normalize the attention scores to probabilities.
  190. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  191. # This is actually dropping out entire tokens to attend to, which might
  192. # seem a bit unusual, but is taken from the original Transformer paper.
  193. attention_probs = self.dropout(attention_probs)
  194. # Mask heads if we want to
  195. if head_mask is not None:
  196. attention_probs = attention_probs * head_mask
  197. context_layer = torch.matmul(attention_probs, values)
  198. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  199. new_context_layer_shape = context_layer.size()[:-2] + (hiddens,)
  200. context_layer = context_layer.view(*new_context_layer_shape)
  201. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  202. return outputs
  203. class PerceiverSelfOutput(nn.Module):
  204. def __init__(self, config, input_channels, output_channels):
  205. super().__init__()
  206. self.dense = nn.Linear(input_channels, output_channels)
  207. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  208. hidden_states = self.dense(hidden_states)
  209. return hidden_states
  210. class PerceiverAttention(nn.Module):
  211. """Attention module, including a dense block."""
  212. def __init__(
  213. self,
  214. config,
  215. is_cross_attention=False,
  216. qk_channels=None,
  217. v_channels=None,
  218. num_heads=1,
  219. q_dim=None,
  220. kv_dim=None,
  221. use_query_residual=True,
  222. ):
  223. super().__init__()
  224. # MultiHead attention
  225. if is_cross_attention and qk_channels is None:
  226. if config.cross_attention_shape_for_attention == "q":
  227. qk_channels = q_dim
  228. elif config.cross_attention_shape_for_attention == "kv":
  229. qk_channels = kv_dim
  230. else:
  231. raise ValueError(
  232. f"Unknown value {config.cross_attention_shape_for_attention} for "
  233. "cross_attention_shape_for_attention."
  234. )
  235. else:
  236. if qk_channels is None:
  237. qk_channels = q_dim
  238. if v_channels is None:
  239. v_channels = qk_channels
  240. self.self = PerceiverSelfAttention(
  241. config,
  242. is_cross_attention=is_cross_attention,
  243. qk_channels=qk_channels,
  244. v_channels=v_channels,
  245. num_heads=num_heads,
  246. q_dim=q_dim,
  247. kv_dim=kv_dim,
  248. )
  249. # dense block
  250. output_channels = None
  251. if is_cross_attention:
  252. output_channels = q_dim
  253. else:
  254. if output_channels is None:
  255. output_channels = v_channels
  256. self.output = PerceiverSelfOutput(config, input_channels=self.self.v_channels, output_channels=output_channels)
  257. self.use_query_residual = use_query_residual
  258. self.pruned_heads = set()
  259. def prune_heads(self, heads):
  260. if len(heads) == 0:
  261. return
  262. heads, index = find_pruneable_heads_and_indices(
  263. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  264. )
  265. # Prune linear layers
  266. self.self.query = prune_linear_layer(self.self.query, index)
  267. self.self.key = prune_linear_layer(self.self.key, index)
  268. self.self.value = prune_linear_layer(self.self.value, index)
  269. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  270. # Update hyper params and store pruned heads
  271. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  272. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  273. self.pruned_heads = self.pruned_heads.union(heads)
  274. def forward(
  275. self,
  276. hidden_states: torch.Tensor,
  277. attention_mask: Optional[torch.FloatTensor] = None,
  278. head_mask: Optional[torch.FloatTensor] = None,
  279. inputs: Optional[torch.FloatTensor] = None,
  280. inputs_mask: Optional[torch.FloatTensor] = None,
  281. output_attentions: Optional[bool] = False,
  282. ) -> tuple[torch.Tensor]:
  283. self_outputs = self.self(
  284. hidden_states,
  285. attention_mask,
  286. head_mask,
  287. inputs,
  288. inputs_mask,
  289. output_attentions,
  290. )
  291. # Output projection
  292. attention_output = self.output(self_outputs[0])
  293. # Optionally include a residual to the original queries.
  294. # Consider omitting the residual if the semantics of query and output
  295. # are different, e.g. if queries are positions and outputs are pixels.
  296. if self.use_query_residual:
  297. attention_output = attention_output + hidden_states
  298. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  299. return outputs
  300. class PerceiverMLP(nn.Module):
  301. """A Transformer-style dense module to follow attention."""
  302. def __init__(self, config, input_size, widening_factor):
  303. super().__init__()
  304. self.dense1 = nn.Linear(input_size, widening_factor * input_size)
  305. if isinstance(config.hidden_act, str):
  306. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  307. else:
  308. self.intermediate_act_fn = config.hidden_act
  309. self.dense2 = nn.Linear(widening_factor * input_size, input_size)
  310. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  311. hidden_states = self.dense1(hidden_states)
  312. hidden_states = self.intermediate_act_fn(hidden_states)
  313. hidden_states = self.dense2(hidden_states)
  314. return hidden_states
  315. class PerceiverLayer(nn.Module):
  316. def __init__(
  317. self,
  318. config,
  319. is_cross_attention=False,
  320. qk_channels=None,
  321. v_channels=None,
  322. num_heads=1,
  323. q_dim=None,
  324. kv_dim=None,
  325. widening_factor=4,
  326. use_query_residual=True,
  327. ):
  328. super().__init__()
  329. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  330. self.seq_len_dim = 1
  331. self.attention = PerceiverAttention(
  332. config,
  333. is_cross_attention=is_cross_attention,
  334. qk_channels=qk_channels,
  335. v_channels=v_channels,
  336. num_heads=num_heads,
  337. q_dim=q_dim,
  338. kv_dim=kv_dim,
  339. use_query_residual=use_query_residual,
  340. )
  341. self.layernorm = nn.LayerNorm(q_dim)
  342. self.mlp = PerceiverMLP(config, input_size=q_dim, widening_factor=widening_factor)
  343. def forward(
  344. self,
  345. hidden_states: torch.Tensor,
  346. attention_mask: Optional[torch.FloatTensor] = None,
  347. head_mask: Optional[torch.FloatTensor] = None,
  348. inputs: Optional[torch.FloatTensor] = None,
  349. inputs_mask: Optional[torch.FloatTensor] = None,
  350. output_attentions: Optional[bool] = False,
  351. ) -> tuple[torch.Tensor]:
  352. attention_outputs = self.attention(
  353. hidden_states,
  354. attention_mask,
  355. head_mask,
  356. inputs,
  357. inputs_mask,
  358. output_attentions,
  359. )
  360. attention_output = attention_outputs[0]
  361. outputs = attention_outputs[1:] # add attentions if we output attention weights
  362. layer_output = apply_chunking_to_forward(
  363. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  364. )
  365. layer_output = layer_output + attention_output # residual connection
  366. outputs = (layer_output,) + outputs
  367. return outputs
  368. def feed_forward_chunk(self, attention_output):
  369. layer_output = self.layernorm(attention_output)
  370. layer_output = self.mlp(layer_output)
  371. return layer_output
  372. class PerceiverEncoder(nn.Module):
  373. """The Perceiver Encoder: a scalable, fully attentional encoder."""
  374. def __init__(self, config, kv_dim=None):
  375. super().__init__()
  376. self.config = config
  377. # Check that we can use multihead-attention with these shapes.
  378. if config.d_latents % config.num_self_attention_heads != 0:
  379. raise ValueError(
  380. f"num_z_channels ({config.d_latents}) must be divisible by"
  381. f" num_self_attend_heads ({config.num_self_attention_heads})."
  382. )
  383. if config.d_latents % config.num_cross_attention_heads != 0:
  384. raise ValueError(
  385. f"num_z_channels ({config.d_latents}) must be divisible by"
  386. f" num_cross_attend_heads ({config.num_cross_attention_heads})."
  387. )
  388. # Construct the cross attention layer.
  389. self.cross_attention = PerceiverLayer(
  390. config,
  391. is_cross_attention=True,
  392. qk_channels=config.qk_channels,
  393. v_channels=config.v_channels,
  394. num_heads=config.num_cross_attention_heads,
  395. q_dim=config.d_latents,
  396. kv_dim=kv_dim,
  397. widening_factor=config.cross_attention_widening_factor,
  398. use_query_residual=config.use_query_residual,
  399. )
  400. # Construct a single block of self-attention layers.
  401. # We get deeper architectures by applying this block more than once.
  402. self_attention_layers = []
  403. for _ in range(config.num_self_attends_per_block):
  404. layer = PerceiverLayer(
  405. config,
  406. is_cross_attention=False,
  407. qk_channels=config.qk_channels,
  408. v_channels=config.v_channels,
  409. num_heads=config.num_self_attention_heads,
  410. q_dim=config.d_latents,
  411. kv_dim=config.d_latents,
  412. widening_factor=config.self_attention_widening_factor,
  413. )
  414. self_attention_layers.append(layer)
  415. self.self_attends = nn.ModuleList(self_attention_layers)
  416. def forward(
  417. self,
  418. hidden_states: torch.Tensor,
  419. attention_mask: Optional[torch.FloatTensor] = None,
  420. head_mask: Optional[torch.FloatTensor] = None,
  421. inputs: Optional[torch.FloatTensor] = None,
  422. inputs_mask: Optional[torch.FloatTensor] = None,
  423. output_attentions: Optional[bool] = False,
  424. output_hidden_states: Optional[bool] = False,
  425. return_dict: Optional[bool] = True,
  426. ) -> Union[tuple, BaseModelOutputWithCrossAttentions]:
  427. all_hidden_states = () if output_hidden_states else None
  428. all_self_attentions = () if output_attentions else None
  429. all_cross_attentions = () if output_attentions else None
  430. # Apply the cross-attention between the latents (hidden_states) and inputs:
  431. layer_outputs = self.cross_attention(
  432. hidden_states,
  433. attention_mask=attention_mask,
  434. head_mask=None,
  435. inputs=inputs,
  436. inputs_mask=inputs_mask,
  437. output_attentions=output_attentions,
  438. )
  439. hidden_states = layer_outputs[0]
  440. if output_attentions:
  441. all_cross_attentions = all_cross_attentions + (layer_outputs[1],)
  442. # Apply the block of self-attention layers more than once:
  443. for _ in range(self.config.num_blocks):
  444. for i, layer_module in enumerate(self.self_attends):
  445. if output_hidden_states:
  446. all_hidden_states = all_hidden_states + (hidden_states,)
  447. layer_head_mask = head_mask[i] if head_mask is not None else None
  448. layer_outputs = layer_module(
  449. hidden_states,
  450. attention_mask=attention_mask,
  451. head_mask=layer_head_mask,
  452. output_attentions=output_attentions,
  453. )
  454. hidden_states = layer_outputs[0]
  455. if output_attentions:
  456. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  457. if output_hidden_states:
  458. all_hidden_states = all_hidden_states + (hidden_states,)
  459. if not return_dict:
  460. return tuple(
  461. v
  462. for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
  463. if v is not None
  464. )
  465. return BaseModelOutputWithCrossAttentions(
  466. last_hidden_state=hidden_states,
  467. hidden_states=all_hidden_states,
  468. attentions=all_self_attentions,
  469. cross_attentions=all_cross_attentions,
  470. )
  471. @auto_docstring
  472. class PerceiverPreTrainedModel(PreTrainedModel):
  473. config: PerceiverConfig
  474. base_model_prefix = "perceiver"
  475. main_input_name = "inputs"
  476. def _init_weights(self, module):
  477. """Initialize the weights"""
  478. if isinstance(module, (nn.Linear, nn.Conv2d)):
  479. # Slightly different from the TF version which uses truncated_normal for initialization
  480. # cf https://github.com/pytorch/pytorch/pull/5617
  481. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  482. if module.bias is not None:
  483. module.bias.data.zero_()
  484. elif hasattr(module, "latents"):
  485. module.latents.data.normal_(mean=0.0, std=self.config.initializer_range)
  486. elif hasattr(module, "position_embeddings") and isinstance(module, PerceiverTrainablePositionEncoding):
  487. module.position_embeddings.data.normal_(mean=0.0, std=self.config.initializer_range)
  488. elif isinstance(module, nn.ParameterDict):
  489. for modality in module:
  490. module[modality].data.normal_(mean=0.0, std=self.config.initializer_range)
  491. elif isinstance(module, nn.Embedding):
  492. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  493. if module.padding_idx is not None:
  494. module.weight.data[module.padding_idx].zero_()
  495. elif isinstance(module, nn.LayerNorm):
  496. module.bias.data.zero_()
  497. module.weight.data.fill_(1.0)
  498. @auto_docstring(
  499. custom_intro="""
  500. The Perceiver: a scalable, fully attentional architecture.
  501. <Tip>
  502. Note that it's possible to fine-tune Perceiver on higher resolution images than the ones it has been trained on, by
  503. setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
  504. position embeddings to the higher resolution.
  505. </Tip>
  506. """
  507. )
  508. class PerceiverModel(PerceiverPreTrainedModel):
  509. def __init__(
  510. self,
  511. config,
  512. decoder: Optional["PerceiverAbstractDecoder"] = None,
  513. input_preprocessor: PreprocessorType = None,
  514. output_postprocessor: PostprocessorType = None,
  515. ):
  516. r"""
  517. decoder (`PerceiverDecoder`, *optional*):
  518. Decoder module that transforms latent representations into task predictions.
  519. input_preprocessor (`PreprocessorType`, *optional*):
  520. Preprocessor that encodes raw inputs into tensors for the model.
  521. output_postprocessor (`PostprocessorType`, *optional*):
  522. Postprocessor that transforms model outputs into final predictions.
  523. """
  524. super().__init__(config)
  525. self.config = config
  526. self.input_preprocessor = input_preprocessor
  527. self.output_postprocessor = output_postprocessor
  528. self.embeddings = PerceiverEmbeddings(config)
  529. self.encoder = PerceiverEncoder(
  530. config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model
  531. )
  532. self.decoder = decoder
  533. # Initialize weights and apply final processing
  534. self.post_init()
  535. def get_input_embeddings(self):
  536. return self.embeddings.latents
  537. def set_input_embeddings(self, value):
  538. self.embeddings.latents = value
  539. def _prune_heads(self, heads_to_prune):
  540. """
  541. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  542. class PreTrainedModel
  543. """
  544. for layer, heads in heads_to_prune.items():
  545. self.encoder.layer[layer].attention.prune_heads(heads)
  546. @auto_docstring
  547. def forward(
  548. self,
  549. inputs: torch.FloatTensor,
  550. attention_mask: Optional[torch.FloatTensor] = None,
  551. subsampled_output_points: Optional[dict[str, torch.Tensor]] = None,
  552. head_mask: Optional[torch.FloatTensor] = None,
  553. output_attentions: Optional[bool] = None,
  554. output_hidden_states: Optional[bool] = None,
  555. interpolate_pos_encoding: bool = False,
  556. return_dict: Optional[bool] = None,
  557. ) -> Union[tuple, PerceiverModelOutput]:
  558. r"""
  559. inputs (`torch.FloatTensor`):
  560. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  561. subsampled_output_points (`dict[str, torch.Tensor]`, *optional*):
  562. Dictionary of tensors used as queries for the decoder. The decoder maps these queries to the latent
  563. representation of the model. Used for subsampled decoding, e.g. when only decoding certain image patches.
  564. Examples:
  565. ```python
  566. >>> from transformers import PerceiverConfig, PerceiverTokenizer, PerceiverImageProcessor, PerceiverModel
  567. >>> from transformers.models.perceiver.modeling_perceiver import (
  568. ... PerceiverTextPreprocessor,
  569. ... PerceiverImagePreprocessor,
  570. ... PerceiverClassificationDecoder,
  571. ... )
  572. >>> import torch
  573. >>> import requests
  574. >>> from PIL import Image
  575. >>> # EXAMPLE 1: using the Perceiver to classify texts
  576. >>> # - we define a TextPreprocessor, which can be used to embed tokens
  577. >>> # - we define a ClassificationDecoder, which can be used to decode the
  578. >>> # final hidden states of the latents to classification logits
  579. >>> # using trainable position embeddings
  580. >>> config = PerceiverConfig()
  581. >>> preprocessor = PerceiverTextPreprocessor(config)
  582. >>> decoder = PerceiverClassificationDecoder(
  583. ... config,
  584. ... num_channels=config.d_latents,
  585. ... trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
  586. ... use_query_residual=True,
  587. ... )
  588. >>> model = PerceiverModel(config, input_preprocessor=preprocessor, decoder=decoder)
  589. >>> # you can then do a forward pass as follows:
  590. >>> tokenizer = PerceiverTokenizer()
  591. >>> text = "hello world"
  592. >>> inputs = tokenizer(text, return_tensors="pt").input_ids
  593. >>> with torch.no_grad():
  594. ... outputs = model(inputs=inputs)
  595. >>> logits = outputs.logits
  596. >>> list(logits.shape)
  597. [1, 2]
  598. >>> # to train, one can train the model using standard cross-entropy:
  599. >>> criterion = torch.nn.CrossEntropyLoss()
  600. >>> labels = torch.tensor([1])
  601. >>> loss = criterion(logits, labels)
  602. >>> # EXAMPLE 2: using the Perceiver to classify images
  603. >>> # - we define an ImagePreprocessor, which can be used to embed images
  604. >>> config = PerceiverConfig(image_size=224)
  605. >>> preprocessor = PerceiverImagePreprocessor(
  606. ... config,
  607. ... prep_type="conv1x1",
  608. ... spatial_downsample=1,
  609. ... out_channels=256,
  610. ... position_encoding_type="trainable",
  611. ... concat_or_add_pos="concat",
  612. ... project_pos_dim=256,
  613. ... trainable_position_encoding_kwargs=dict(
  614. ... num_channels=256,
  615. ... index_dims=config.image_size**2,
  616. ... ),
  617. ... )
  618. >>> model = PerceiverModel(
  619. ... config,
  620. ... input_preprocessor=preprocessor,
  621. ... decoder=PerceiverClassificationDecoder(
  622. ... config,
  623. ... num_channels=config.d_latents,
  624. ... trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
  625. ... use_query_residual=True,
  626. ... ),
  627. ... )
  628. >>> # you can then do a forward pass as follows:
  629. >>> image_processor = PerceiverImageProcessor()
  630. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  631. >>> image = Image.open(requests.get(url, stream=True).raw)
  632. >>> inputs = image_processor(image, return_tensors="pt").pixel_values
  633. >>> with torch.no_grad():
  634. ... outputs = model(inputs=inputs)
  635. >>> logits = outputs.logits
  636. >>> list(logits.shape)
  637. [1, 2]
  638. >>> # to train, one can train the model using standard cross-entropy:
  639. >>> criterion = torch.nn.CrossEntropyLoss()
  640. >>> labels = torch.tensor([1])
  641. >>> loss = criterion(logits, labels)
  642. ```"""
  643. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  644. output_hidden_states = (
  645. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  646. )
  647. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  648. if self.input_preprocessor is not None:
  649. inputs, modality_sizes, inputs_without_pos = self.input_preprocessor(
  650. inputs, interpolate_pos_encoding=interpolate_pos_encoding
  651. )
  652. else:
  653. modality_sizes = None
  654. inputs_without_pos = None
  655. if inputs.size()[-1] != self.config.d_model:
  656. raise ValueError(
  657. f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model:"
  658. f" {self.config.d_model}. Make sure to set config.d_model appropriately."
  659. )
  660. batch_size, seq_length, _ = inputs.size()
  661. device = inputs.device
  662. # If no attention mask is provided, make them all ones
  663. if attention_mask is None:
  664. attention_mask = torch.ones((batch_size, seq_length), device=device)
  665. # Make the attention mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
  666. extended_attention_mask = self.invert_attention_mask(attention_mask)
  667. # Prepare head mask if needed
  668. # 1.0 in head_mask indicate we keep the head
  669. # attention_probs has shape bsz x n_heads x N x N
  670. # input head_mask has shape [num_heads] or [num_blocks x num_heads]
  671. # and head_mask is converted to shape [num_blocks x batch x num_heads x N x N]
  672. head_mask = self.get_head_mask(head_mask, self.config.num_blocks * self.config.num_self_attends_per_block)
  673. embedding_output = self.embeddings(batch_size=batch_size)
  674. encoder_outputs = self.encoder(
  675. embedding_output,
  676. attention_mask=None,
  677. head_mask=head_mask,
  678. inputs=inputs,
  679. inputs_mask=extended_attention_mask,
  680. output_attentions=output_attentions,
  681. output_hidden_states=output_hidden_states,
  682. return_dict=return_dict,
  683. )
  684. sequence_output = encoder_outputs[0]
  685. logits = None
  686. if self.decoder:
  687. if subsampled_output_points is not None:
  688. output_modality_sizes = {
  689. "audio": subsampled_output_points["audio"].shape[0],
  690. "image": subsampled_output_points["image"].shape[0],
  691. "label": 1,
  692. }
  693. else:
  694. output_modality_sizes = modality_sizes
  695. decoder_query = self.decoder.decoder_query(
  696. inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_output_points
  697. )
  698. decoder_outputs = self.decoder(
  699. decoder_query,
  700. z=sequence_output,
  701. query_mask=extended_attention_mask,
  702. output_attentions=output_attentions,
  703. )
  704. logits = decoder_outputs.logits
  705. # add cross-attentions of decoder
  706. if output_attentions and decoder_outputs.cross_attentions is not None:
  707. if return_dict:
  708. encoder_outputs.cross_attentions = (
  709. encoder_outputs.cross_attentions + decoder_outputs.cross_attentions
  710. )
  711. else:
  712. encoder_outputs = encoder_outputs + decoder_outputs.cross_attentions
  713. if self.output_postprocessor:
  714. logits = self.output_postprocessor(logits, modality_sizes=output_modality_sizes)
  715. if not return_dict:
  716. if logits is not None:
  717. return (logits, sequence_output) + encoder_outputs[1:]
  718. else:
  719. return (sequence_output,) + encoder_outputs[1:]
  720. return PerceiverModelOutput(
  721. logits=logits,
  722. last_hidden_state=sequence_output,
  723. hidden_states=encoder_outputs.hidden_states,
  724. attentions=encoder_outputs.attentions,
  725. cross_attentions=encoder_outputs.cross_attentions,
  726. )
  727. @auto_docstring(
  728. custom_intro="""
  729. Example use of Perceiver for masked language modeling.
  730. """
  731. )
  732. class PerceiverForMaskedLM(PerceiverPreTrainedModel):
  733. def __init__(self, config: PerceiverConfig):
  734. super().__init__(config)
  735. text_preprocessor = PerceiverTextPreprocessor(config)
  736. trainable_position_encoding_kwargs_decoder = {
  737. "num_channels": text_preprocessor.num_channels,
  738. "index_dims": config.max_position_embeddings,
  739. }
  740. self.perceiver = PerceiverModel(
  741. config,
  742. input_preprocessor=text_preprocessor,
  743. decoder=PerceiverBasicDecoder(
  744. config,
  745. output_num_channels=config.d_latents,
  746. output_index_dims=config.max_position_embeddings, # we need to define the seq_len of the inputs beforehand
  747. num_channels=text_preprocessor.num_channels,
  748. qk_channels=8 * 32,
  749. v_channels=text_preprocessor.num_channels,
  750. num_heads=8,
  751. use_query_residual=False,
  752. final_project=False,
  753. trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
  754. ),
  755. )
  756. self.embedding_decoder = PerceiverEmbeddingDecoder(config)
  757. # Initialize weights and apply final processing
  758. self.post_init()
  759. @auto_docstring
  760. def forward(
  761. self,
  762. inputs: Optional[torch.Tensor] = None,
  763. attention_mask: Optional[torch.Tensor] = None,
  764. head_mask: Optional[torch.Tensor] = None,
  765. output_attentions: Optional[bool] = None,
  766. output_hidden_states: Optional[bool] = None,
  767. labels: Optional[torch.Tensor] = None,
  768. return_dict: Optional[bool] = None,
  769. input_ids: Optional[torch.Tensor] = None,
  770. ) -> Union[tuple, PerceiverMaskedLMOutput]:
  771. r"""
  772. inputs (`torch.FloatTensor`):
  773. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  774. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  775. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  776. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  777. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  778. Examples:
  779. ```python
  780. >>> from transformers import AutoTokenizer, PerceiverForMaskedLM
  781. >>> import torch
  782. >>> tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver")
  783. >>> model = PerceiverForMaskedLM.from_pretrained("deepmind/language-perceiver")
  784. >>> # training
  785. >>> text = "This is an incomplete sentence where some words are missing."
  786. >>> inputs = tokenizer(text, padding="max_length", return_tensors="pt")
  787. >>> # mask " missing."
  788. >>> inputs["input_ids"][0, 52:61] = tokenizer.mask_token_id
  789. >>> labels = tokenizer(text, padding="max_length", return_tensors="pt").input_ids
  790. >>> outputs = model(**inputs, labels=labels)
  791. >>> loss = outputs.loss
  792. >>> round(loss.item(), 2)
  793. 19.87
  794. >>> logits = outputs.logits
  795. >>> list(logits.shape)
  796. [1, 2048, 262]
  797. >>> # inference
  798. >>> text = "This is an incomplete sentence where some words are missing."
  799. >>> encoding = tokenizer(text, padding="max_length", return_tensors="pt")
  800. >>> # mask bytes corresponding to " missing.". Note that the model performs much better if the masked span starts with a space.
  801. >>> encoding["input_ids"][0, 52:61] = tokenizer.mask_token_id
  802. >>> # forward pass
  803. >>> with torch.no_grad():
  804. ... outputs = model(**encoding)
  805. >>> logits = outputs.logits
  806. >>> list(logits.shape)
  807. [1, 2048, 262]
  808. >>> masked_tokens_predictions = logits[0, 52:61].argmax(dim=-1).tolist()
  809. >>> tokenizer.decode(masked_tokens_predictions)
  810. ' missing.'
  811. ```"""
  812. if inputs is not None and input_ids is not None:
  813. raise ValueError("You cannot use both `inputs` and `input_ids`")
  814. elif inputs is None and input_ids is not None:
  815. inputs = input_ids
  816. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  817. outputs = self.perceiver(
  818. inputs=inputs,
  819. attention_mask=attention_mask,
  820. head_mask=head_mask,
  821. output_attentions=output_attentions,
  822. output_hidden_states=output_hidden_states,
  823. return_dict=return_dict,
  824. )
  825. logits = self.embedding_decoder(
  826. outputs.logits if return_dict else outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings
  827. )
  828. masked_lm_loss = None
  829. if labels is not None:
  830. loss_fct = CrossEntropyLoss() # -100 index = padding token
  831. masked_lm_loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  832. if not return_dict:
  833. output = (logits,) + outputs[2:]
  834. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  835. return PerceiverMaskedLMOutput(
  836. loss=masked_lm_loss,
  837. logits=logits,
  838. hidden_states=outputs.hidden_states,
  839. attentions=outputs.attentions,
  840. cross_attentions=outputs.cross_attentions,
  841. )
  842. @auto_docstring(
  843. custom_intro="""
  844. Example use of Perceiver for text classification.
  845. """
  846. )
  847. class PerceiverForSequenceClassification(PerceiverPreTrainedModel):
  848. def __init__(self, config):
  849. super().__init__(config)
  850. trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}
  851. self.num_labels = config.num_labels
  852. self.perceiver = PerceiverModel(
  853. config,
  854. input_preprocessor=PerceiverTextPreprocessor(config),
  855. decoder=PerceiverClassificationDecoder(
  856. config,
  857. num_channels=config.d_latents,
  858. trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
  859. use_query_residual=True,
  860. ),
  861. )
  862. # Initialize weights and apply final processing
  863. self.post_init()
  864. @auto_docstring
  865. def forward(
  866. self,
  867. inputs: Optional[torch.Tensor] = None,
  868. attention_mask: Optional[torch.Tensor] = None,
  869. head_mask: Optional[torch.Tensor] = None,
  870. output_attentions: Optional[bool] = None,
  871. output_hidden_states: Optional[bool] = None,
  872. labels: Optional[torch.Tensor] = None,
  873. return_dict: Optional[bool] = None,
  874. input_ids: Optional[torch.Tensor] = None,
  875. ) -> Union[tuple, PerceiverClassifierOutput]:
  876. r"""
  877. inputs (`torch.FloatTensor`):
  878. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  879. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  880. Labels for computing the classification/regression loss. Indices should be in `[0, ..., config.num_labels -
  881. 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels >
  882. 1` a classification loss is computed (Cross-Entropy).
  883. Examples:
  884. ```python
  885. >>> from transformers import AutoTokenizer, PerceiverForSequenceClassification
  886. >>> tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver")
  887. >>> model = PerceiverForSequenceClassification.from_pretrained("deepmind/language-perceiver")
  888. >>> text = "hello world"
  889. >>> inputs = tokenizer(text, return_tensors="pt").input_ids
  890. >>> outputs = model(inputs=inputs)
  891. >>> logits = outputs.logits
  892. >>> list(logits.shape)
  893. [1, 2]
  894. ```"""
  895. if inputs is not None and input_ids is not None:
  896. raise ValueError("You cannot use both `inputs` and `input_ids`")
  897. elif inputs is None and input_ids is not None:
  898. inputs = input_ids
  899. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  900. outputs = self.perceiver(
  901. inputs=inputs,
  902. attention_mask=attention_mask,
  903. head_mask=head_mask,
  904. output_attentions=output_attentions,
  905. output_hidden_states=output_hidden_states,
  906. return_dict=return_dict,
  907. )
  908. logits = outputs.logits if return_dict else outputs[0]
  909. loss = None
  910. if labels is not None:
  911. if self.config.problem_type is None:
  912. if self.num_labels == 1:
  913. self.config.problem_type = "regression"
  914. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  915. self.config.problem_type = "single_label_classification"
  916. else:
  917. self.config.problem_type = "multi_label_classification"
  918. if self.config.problem_type == "regression":
  919. loss_fct = MSELoss()
  920. if self.num_labels == 1:
  921. loss = loss_fct(logits.squeeze(), labels.squeeze())
  922. else:
  923. loss = loss_fct(logits, labels)
  924. elif self.config.problem_type == "single_label_classification":
  925. loss_fct = CrossEntropyLoss()
  926. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  927. elif self.config.problem_type == "multi_label_classification":
  928. loss_fct = BCEWithLogitsLoss()
  929. loss = loss_fct(logits, labels)
  930. if not return_dict:
  931. output = (logits,) + outputs[2:]
  932. return ((loss,) + output) if loss is not None else output
  933. return PerceiverClassifierOutput(
  934. loss=loss,
  935. logits=logits,
  936. hidden_states=outputs.hidden_states,
  937. attentions=outputs.attentions,
  938. cross_attentions=outputs.cross_attentions,
  939. )
  940. @auto_docstring(
  941. custom_intro="""
  942. Example use of Perceiver for image classification, for tasks such as ImageNet.
  943. This model uses learned position embeddings. In other words, this model is not given any privileged information about
  944. the structure of images. As shown in the paper, this model can achieve a top-1 accuracy of 72.7 on ImageNet.
  945. [`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
  946. (with `prep_type="conv1x1"`) to preprocess the input images, and
  947. [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
  948. [`PerceiverModel`] into classification logits.
  949. """
  950. )
  951. class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
  952. def __init__(self, config):
  953. super().__init__(config)
  954. trainable_position_encoding_kwargs_preprocessor = {"num_channels": 256, "index_dims": config.image_size**2}
  955. trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}
  956. self.num_labels = config.num_labels
  957. self.perceiver = PerceiverModel(
  958. config,
  959. input_preprocessor=PerceiverImagePreprocessor(
  960. config,
  961. prep_type="conv1x1",
  962. spatial_downsample=1,
  963. out_channels=256,
  964. position_encoding_type="trainable",
  965. concat_or_add_pos="concat",
  966. project_pos_dim=256,
  967. trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_preprocessor,
  968. ),
  969. decoder=PerceiverClassificationDecoder(
  970. config,
  971. num_channels=config.d_latents,
  972. trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
  973. use_query_residual=True,
  974. ),
  975. )
  976. # Initialize weights and apply final processing
  977. self.post_init()
  978. @auto_docstring
  979. def forward(
  980. self,
  981. inputs: Optional[torch.Tensor] = None,
  982. attention_mask: Optional[torch.Tensor] = None,
  983. head_mask: Optional[torch.Tensor] = None,
  984. output_attentions: Optional[bool] = None,
  985. output_hidden_states: Optional[bool] = None,
  986. labels: Optional[torch.Tensor] = None,
  987. interpolate_pos_encoding: bool = False,
  988. return_dict: Optional[bool] = None,
  989. pixel_values: Optional[torch.Tensor] = None,
  990. ) -> Union[tuple, PerceiverClassifierOutput]:
  991. r"""
  992. inputs (`torch.FloatTensor`):
  993. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  994. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  995. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  996. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  997. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  998. Examples:
  999. ```python
  1000. >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationLearned
  1001. >>> from PIL import Image
  1002. >>> import requests
  1003. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1004. >>> image = Image.open(requests.get(url, stream=True).raw)
  1005. >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-learned")
  1006. >>> model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned")
  1007. >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values
  1008. >>> outputs = model(inputs=inputs)
  1009. >>> logits = outputs.logits
  1010. >>> list(logits.shape)
  1011. [1, 1000]
  1012. >>> # model predicts one of the 1000 ImageNet classes
  1013. >>> predicted_class_idx = logits.argmax(-1).item()
  1014. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  1015. Predicted class: tabby, tabby cat
  1016. ```"""
  1017. if inputs is not None and pixel_values is not None:
  1018. raise ValueError("You cannot use both `inputs` and `pixel_values`")
  1019. elif inputs is None and pixel_values is not None:
  1020. inputs = pixel_values
  1021. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1022. outputs = self.perceiver(
  1023. inputs=inputs,
  1024. attention_mask=attention_mask,
  1025. head_mask=head_mask,
  1026. output_attentions=output_attentions,
  1027. output_hidden_states=output_hidden_states,
  1028. interpolate_pos_encoding=interpolate_pos_encoding,
  1029. return_dict=return_dict,
  1030. )
  1031. logits = outputs.logits if return_dict else outputs[0]
  1032. loss = None
  1033. if labels is not None:
  1034. loss = self.loss_function(labels, logits, self.config)
  1035. if not return_dict:
  1036. output = (logits,) + outputs[2:]
  1037. return ((loss,) + output) if loss is not None else output
  1038. return PerceiverClassifierOutput(
  1039. loss=loss,
  1040. logits=logits,
  1041. hidden_states=outputs.hidden_states,
  1042. attentions=outputs.attentions,
  1043. cross_attentions=outputs.cross_attentions,
  1044. )
  1045. @auto_docstring(
  1046. custom_intro="""
  1047. Example use of Perceiver for image classification, for tasks such as ImageNet.
  1048. This model uses fixed 2D Fourier position embeddings. As shown in the paper, this model can achieve a top-1 accuracy of
  1049. 79.0 on ImageNet, and 84.5 when pre-trained on a large-scale dataset (i.e. JFT).
  1050. [`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
  1051. (with `prep_type="pixels"`) to preprocess the input images, and
  1052. [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
  1053. [`PerceiverModel`] into classification logits.
  1054. """
  1055. )
  1056. class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel):
  1057. def __init__(self, config):
  1058. super().__init__(config)
  1059. fourier_position_encoding_kwargs_preprocessor = {
  1060. "concat_pos": True,
  1061. "max_resolution": (224, 224),
  1062. "num_bands": 64,
  1063. "sine_only": False,
  1064. }
  1065. trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}
  1066. self.num_labels = config.num_labels
  1067. self.perceiver = PerceiverModel(
  1068. config,
  1069. input_preprocessor=PerceiverImagePreprocessor(
  1070. config,
  1071. prep_type="pixels",
  1072. spatial_downsample=1,
  1073. fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
  1074. ),
  1075. decoder=PerceiverClassificationDecoder(
  1076. config,
  1077. num_channels=config.d_latents,
  1078. trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
  1079. use_query_residual=True,
  1080. ),
  1081. )
  1082. # Initialize weights and apply final processing
  1083. self.post_init()
  1084. @auto_docstring
  1085. def forward(
  1086. self,
  1087. inputs: Optional[torch.Tensor] = None,
  1088. attention_mask: Optional[torch.Tensor] = None,
  1089. head_mask: Optional[torch.Tensor] = None,
  1090. output_attentions: Optional[bool] = None,
  1091. output_hidden_states: Optional[bool] = None,
  1092. labels: Optional[torch.Tensor] = None,
  1093. return_dict: Optional[bool] = None,
  1094. pixel_values: Optional[torch.Tensor] = None,
  1095. ) -> Union[tuple, PerceiverClassifierOutput]:
  1096. r"""
  1097. inputs (`torch.FloatTensor`):
  1098. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  1099. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1100. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  1101. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1102. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1103. Examples:
  1104. ```python
  1105. >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationFourier
  1106. >>> from PIL import Image
  1107. >>> import requests
  1108. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1109. >>> image = Image.open(requests.get(url, stream=True).raw)
  1110. >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-fourier")
  1111. >>> model = PerceiverForImageClassificationFourier.from_pretrained("deepmind/vision-perceiver-fourier")
  1112. >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values
  1113. >>> outputs = model(inputs=inputs)
  1114. >>> logits = outputs.logits
  1115. >>> list(logits.shape)
  1116. [1, 1000]
  1117. >>> # model predicts one of the 1000 ImageNet classes
  1118. >>> predicted_class_idx = logits.argmax(-1).item()
  1119. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  1120. Predicted class: tabby, tabby cat
  1121. ```"""
  1122. if inputs is not None and pixel_values is not None:
  1123. raise ValueError("You cannot use both `inputs` and `pixel_values`")
  1124. elif inputs is None and pixel_values is not None:
  1125. inputs = pixel_values
  1126. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1127. outputs = self.perceiver(
  1128. inputs=inputs,
  1129. attention_mask=attention_mask,
  1130. head_mask=head_mask,
  1131. output_attentions=output_attentions,
  1132. output_hidden_states=output_hidden_states,
  1133. return_dict=return_dict,
  1134. )
  1135. logits = outputs.logits if return_dict else outputs[0]
  1136. loss = None
  1137. if labels is not None:
  1138. loss = self.loss_function(labels, logits, self.config)
  1139. if not return_dict:
  1140. output = (logits,) + outputs[2:]
  1141. return ((loss,) + output) if loss is not None else output
  1142. return PerceiverClassifierOutput(
  1143. loss=loss,
  1144. logits=logits,
  1145. hidden_states=outputs.hidden_states,
  1146. attentions=outputs.attentions,
  1147. cross_attentions=outputs.cross_attentions,
  1148. )
  1149. @auto_docstring(
  1150. custom_intro="""
  1151. Example use of Perceiver for image classification, for tasks such as ImageNet.
  1152. This model uses a 2D conv+maxpool preprocessing network. As shown in the paper, this model can achieve a top-1 accuracy
  1153. of 82.1 on ImageNet.
  1154. [`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
  1155. (with `prep_type="conv"`) to preprocess the input images, and
  1156. [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
  1157. [`PerceiverModel`] into classification logits.
  1158. """
  1159. )
  1160. class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel):
  1161. def __init__(self, config):
  1162. super().__init__(config)
  1163. fourier_position_encoding_kwargs_preprocessor = {
  1164. "concat_pos": True,
  1165. "max_resolution": (56, 56),
  1166. "num_bands": 64,
  1167. "sine_only": False,
  1168. }
  1169. trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}
  1170. self.num_labels = config.num_labels
  1171. self.perceiver = PerceiverModel(
  1172. config,
  1173. input_preprocessor=PerceiverImagePreprocessor(
  1174. config,
  1175. prep_type="conv",
  1176. spatial_downsample=1,
  1177. position_encoding_type="fourier",
  1178. fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
  1179. ),
  1180. decoder=PerceiverClassificationDecoder(
  1181. config,
  1182. num_channels=config.d_latents,
  1183. trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
  1184. use_query_residual=True,
  1185. ),
  1186. )
  1187. # Initialize weights and apply final processing
  1188. self.post_init()
  1189. @auto_docstring
  1190. def forward(
  1191. self,
  1192. inputs: Optional[torch.Tensor] = None,
  1193. attention_mask: Optional[torch.Tensor] = None,
  1194. head_mask: Optional[torch.Tensor] = None,
  1195. output_attentions: Optional[bool] = None,
  1196. output_hidden_states: Optional[bool] = None,
  1197. labels: Optional[torch.Tensor] = None,
  1198. return_dict: Optional[bool] = None,
  1199. pixel_values: Optional[torch.Tensor] = None,
  1200. ) -> Union[tuple, PerceiverClassifierOutput]:
  1201. r"""
  1202. inputs (`torch.FloatTensor`):
  1203. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  1204. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1205. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  1206. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1207. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1208. Examples:
  1209. ```python
  1210. >>> from transformers import AutoImageProcessor, PerceiverForImageClassificationConvProcessing
  1211. >>> from PIL import Image
  1212. >>> import requests
  1213. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1214. >>> image = Image.open(requests.get(url, stream=True).raw)
  1215. >>> image_processor = AutoImageProcessor.from_pretrained("deepmind/vision-perceiver-conv")
  1216. >>> model = PerceiverForImageClassificationConvProcessing.from_pretrained("deepmind/vision-perceiver-conv")
  1217. >>> inputs = image_processor(images=image, return_tensors="pt").pixel_values
  1218. >>> outputs = model(inputs=inputs)
  1219. >>> logits = outputs.logits
  1220. >>> list(logits.shape)
  1221. [1, 1000]
  1222. >>> # model predicts one of the 1000 ImageNet classes
  1223. >>> predicted_class_idx = logits.argmax(-1).item()
  1224. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  1225. Predicted class: tabby, tabby cat
  1226. ```"""
  1227. if inputs is not None and pixel_values is not None:
  1228. raise ValueError("You cannot use both `inputs` and `pixel_values`")
  1229. elif inputs is None and pixel_values is not None:
  1230. inputs = pixel_values
  1231. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1232. outputs = self.perceiver(
  1233. inputs=inputs,
  1234. attention_mask=attention_mask,
  1235. head_mask=head_mask,
  1236. output_attentions=output_attentions,
  1237. output_hidden_states=output_hidden_states,
  1238. return_dict=return_dict,
  1239. )
  1240. logits = outputs.logits if return_dict else outputs[0]
  1241. loss = None
  1242. if labels is not None:
  1243. loss = self.loss_function(labels, logits, self.config)
  1244. if not return_dict:
  1245. output = (logits,) + outputs[2:]
  1246. return ((loss,) + output) if loss is not None else output
  1247. return PerceiverClassifierOutput(
  1248. loss=loss,
  1249. logits=logits,
  1250. hidden_states=outputs.hidden_states,
  1251. attentions=outputs.attentions,
  1252. cross_attentions=outputs.cross_attentions,
  1253. )
  1254. @auto_docstring(
  1255. custom_intro="""
  1256. Example use of Perceiver for optical flow, for tasks such as Sintel and KITTI. [`PerceiverForOpticalFlow`] uses
  1257. [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] (with *prep_type="patches"*) to preprocess the
  1258. input images, and [`~models.perceiver.modeling_perceiver.PerceiverOpticalFlowDecoder`] to decode the latent
  1259. representation of [`PerceiverModel`].
  1260. As input, one concatenates 2 subsequent frames along the channel dimension and extract a 3 x 3 patch around each pixel
  1261. (leading to 3 x 3 x 3 x 2 = 54 values for each pixel). Fixed Fourier position encodings are used to encode the position
  1262. of each pixel in the patch. Next, one applies the Perceiver encoder. To decode, one queries the latent representation
  1263. using the same encoding used for the input.
  1264. """
  1265. )
  1266. class PerceiverForOpticalFlow(PerceiverPreTrainedModel):
  1267. def __init__(self, config):
  1268. super().__init__(config)
  1269. fourier_position_encoding_kwargs_preprocessor = {
  1270. "num_bands": 64,
  1271. "max_resolution": config.train_size,
  1272. "sine_only": False,
  1273. "concat_pos": True,
  1274. }
  1275. fourier_position_encoding_kwargs_decoder = {
  1276. "concat_pos": True,
  1277. "max_resolution": config.train_size,
  1278. "num_bands": 64,
  1279. "sine_only": False,
  1280. }
  1281. image_preprocessor = PerceiverImagePreprocessor(
  1282. config,
  1283. prep_type="patches",
  1284. spatial_downsample=1,
  1285. conv_after_patching=True,
  1286. conv_after_patching_in_channels=54,
  1287. temporal_downsample=2,
  1288. position_encoding_type="fourier",
  1289. # position_encoding_kwargs
  1290. fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
  1291. )
  1292. self.perceiver = PerceiverModel(
  1293. config,
  1294. input_preprocessor=image_preprocessor,
  1295. decoder=PerceiverOpticalFlowDecoder(
  1296. config,
  1297. num_channels=image_preprocessor.num_channels,
  1298. output_image_shape=config.train_size,
  1299. rescale_factor=100.0,
  1300. # decoder kwargs
  1301. use_query_residual=False,
  1302. output_num_channels=2,
  1303. # We query the decoder using the first frame features
  1304. # rather than a standard decoder position encoding.
  1305. position_encoding_type="fourier",
  1306. fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_decoder,
  1307. ),
  1308. )
  1309. # Initialize weights and apply final processing
  1310. self.post_init()
  1311. @auto_docstring
  1312. def forward(
  1313. self,
  1314. inputs: Optional[torch.Tensor] = None,
  1315. attention_mask: Optional[torch.Tensor] = None,
  1316. head_mask: Optional[torch.Tensor] = None,
  1317. output_attentions: Optional[bool] = None,
  1318. output_hidden_states: Optional[bool] = None,
  1319. labels: Optional[torch.Tensor] = None,
  1320. return_dict: Optional[bool] = None,
  1321. ) -> Union[tuple, PerceiverClassifierOutput]:
  1322. r"""
  1323. inputs (`torch.FloatTensor`):
  1324. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  1325. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1326. Labels for computing the optical flow loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1327. Examples:
  1328. ```python
  1329. >>> from transformers import PerceiverForOpticalFlow
  1330. >>> import torch
  1331. >>> model = PerceiverForOpticalFlow.from_pretrained("deepmind/optical-flow-perceiver")
  1332. >>> # in the Perceiver IO paper, the authors extract a 3 x 3 patch around each pixel,
  1333. >>> # leading to 3 x 3 x 3 = 27 values for each pixel (as each pixel also has 3 color channels)
  1334. >>> # patches have shape (batch_size, num_frames, num_channels, height, width)
  1335. >>> # the authors train on resolutions of 368 x 496
  1336. >>> patches = torch.randn(1, 2, 27, 368, 496)
  1337. >>> outputs = model(inputs=patches)
  1338. >>> logits = outputs.logits
  1339. >>> list(logits.shape)
  1340. [1, 368, 496, 2]
  1341. ```"""
  1342. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1343. loss = None
  1344. if labels is not None:
  1345. raise NotImplementedError("Optical flow training is not yet supported")
  1346. outputs = self.perceiver(
  1347. inputs=inputs,
  1348. attention_mask=attention_mask,
  1349. head_mask=head_mask,
  1350. output_attentions=output_attentions,
  1351. output_hidden_states=output_hidden_states,
  1352. return_dict=return_dict,
  1353. )
  1354. logits = outputs.logits if return_dict else outputs[0]
  1355. if not return_dict:
  1356. output = (logits,) + outputs[2:]
  1357. return ((loss,) + output) if loss is not None else output
  1358. return PerceiverClassifierOutput(
  1359. loss=loss,
  1360. logits=logits,
  1361. hidden_states=outputs.hidden_states,
  1362. attentions=outputs.attentions,
  1363. cross_attentions=outputs.cross_attentions,
  1364. )
  1365. @auto_docstring(
  1366. custom_intro="""
  1367. Example use of Perceiver for multimodal (video) autoencoding, for tasks such as Kinetics-700.
  1368. [`PerceiverForMultimodalAutoencoding`] uses [`~models.perceiver.modeling_perceiver.PerceiverMultimodalPreprocessor`] to
  1369. preprocess the 3 modalities: images, audio and class labels. This preprocessor uses modality-specific preprocessors to
  1370. preprocess every modality separately, after which they are concatenated. Trainable position embeddings are used to pad
  1371. each modality to the same number of channels to make concatenation along the time dimension possible. Next, one applies
  1372. the Perceiver encoder.
  1373. [`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] is used to decode the latent representation of
  1374. [`PerceiverModel`]. This decoder uses each modality-specific decoder to construct queries. The decoder queries are
  1375. created based on the inputs after preprocessing. However, autoencoding an entire video in a single forward pass is
  1376. computationally infeasible, hence one only uses parts of the decoder queries to do cross-attention with the latent
  1377. representation. This is determined by the subsampled indices for each modality, which can be provided as additional
  1378. input to the forward pass of [`PerceiverForMultimodalAutoencoding`].
  1379. [`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] also pads the decoder queries of the different
  1380. modalities to the same number of channels, in order to concatenate them along the time dimension. Next, cross-attention
  1381. is performed with the latent representation of [`PerceiverModel`].
  1382. Finally, [`~models.perceiver.modeling_perceiver.PerceiverMultiModalPostprocessor`] is used to turn this tensor into an
  1383. actual video. It first splits up the output into the different modalities, and then applies the respective
  1384. postprocessor for each modality.
  1385. Note that, by masking the classification label during evaluation (i.e. simply providing a tensor of zeros for the
  1386. "label" modality), this auto-encoding model becomes a Kinetics 700 video classifier.
  1387. """
  1388. )
  1389. class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
  1390. def __init__(self, config: PerceiverConfig):
  1391. super().__init__(config)
  1392. n_audio_samples = config.num_frames * config.audio_samples_per_frame
  1393. input_preprocessor = PerceiverMultimodalPreprocessor(
  1394. min_padding_size=4,
  1395. modalities={
  1396. "audio": PerceiverAudioPreprocessor(
  1397. config,
  1398. position_encoding_type="fourier",
  1399. fourier_position_encoding_kwargs={
  1400. "num_bands": 192,
  1401. "max_resolution": (n_audio_samples,),
  1402. "sine_only": False,
  1403. "concat_pos": True,
  1404. },
  1405. prep_type="patches",
  1406. samples_per_patch=config.samples_per_patch,
  1407. ),
  1408. "image": PerceiverImagePreprocessor(
  1409. config,
  1410. position_encoding_type="fourier",
  1411. fourier_position_encoding_kwargs={
  1412. "num_bands": 32,
  1413. "max_resolution": (config.num_frames, config.image_size, config.image_size),
  1414. "sine_only": False,
  1415. "concat_pos": True,
  1416. },
  1417. prep_type="patches",
  1418. spatial_downsample=4,
  1419. temporal_downsample=1,
  1420. ),
  1421. "label": PerceiverOneHotPreprocessor(config),
  1422. },
  1423. mask_probs={"image": 0.0, "audio": 0.0, "label": 1.0},
  1424. )
  1425. image_decoder = PerceiverBasicVideoAutoencodingDecoder(
  1426. config,
  1427. # Autoencoding, don't pass inputs to the queries.
  1428. concat_preprocessed_input=False,
  1429. output_shape=config.output_shape,
  1430. output_num_channels=config.output_num_channels,
  1431. use_query_residual=False,
  1432. position_encoding_only=True,
  1433. position_encoding_type="fourier",
  1434. fourier_position_encoding_kwargs={
  1435. "num_bands": 32,
  1436. "max_resolution": (config.num_frames, config.image_size, config.image_size),
  1437. "sine_only": False,
  1438. "concat_pos": True,
  1439. },
  1440. )
  1441. decoder = PerceiverMultimodalDecoder(
  1442. config,
  1443. # Autoencoding, don't pass inputs to the queries.
  1444. concat_preprocessed_input=False,
  1445. # Modality specific decoders are used ONLY to generate queries.
  1446. # All modalties are decoded together using a unified decoder.
  1447. modalities={
  1448. "audio": PerceiverBasicDecoder(
  1449. config,
  1450. # Autoencoding, don't pass inputs to the queries.
  1451. concat_preprocessed_input=False,
  1452. output_index_dims=(n_audio_samples // config.samples_per_patch,),
  1453. output_num_channels=config.output_num_channels,
  1454. use_query_residual=False,
  1455. position_encoding_only=True,
  1456. position_encoding_type="fourier",
  1457. fourier_position_encoding_kwargs={
  1458. "num_bands": 192,
  1459. "max_resolution": (n_audio_samples,),
  1460. "sine_only": False,
  1461. "concat_pos": True,
  1462. },
  1463. ),
  1464. "image": image_decoder,
  1465. "label": PerceiverClassificationDecoder(
  1466. config,
  1467. # Autoencoding, don't pass inputs to the queries.
  1468. concat_preprocessed_input=False,
  1469. use_query_residual=False,
  1470. position_encoding_only=True,
  1471. position_encoding_type="trainable",
  1472. trainable_position_encoding_kwargs={
  1473. "num_channels": config._label_trainable_num_channels,
  1474. "index_dims": 1,
  1475. },
  1476. ),
  1477. },
  1478. num_outputs=None,
  1479. output_num_channels=config.output_num_channels,
  1480. use_query_residual=False,
  1481. )
  1482. output_postprocessor = PerceiverMultimodalPostprocessor(
  1483. modalities={
  1484. "audio": PerceiverAudioPostprocessor(config, in_channels=config.output_num_channels),
  1485. "image": PerceiverProjectionPostprocessor(in_channels=config.output_num_channels, out_channels=3),
  1486. "label": PerceiverClassificationPostprocessor(config, in_channels=config.output_num_channels),
  1487. }
  1488. )
  1489. self.perceiver = PerceiverModel(
  1490. config,
  1491. input_preprocessor=input_preprocessor,
  1492. decoder=decoder,
  1493. output_postprocessor=output_postprocessor,
  1494. )
  1495. # Initialize weights and apply final processing
  1496. self.post_init()
  1497. @auto_docstring
  1498. def forward(
  1499. self,
  1500. inputs: Optional[torch.Tensor] = None,
  1501. attention_mask: Optional[torch.Tensor] = None,
  1502. subsampled_output_points: Optional[dict[str, torch.Tensor]] = None,
  1503. head_mask: Optional[torch.Tensor] = None,
  1504. output_attentions: Optional[bool] = None,
  1505. output_hidden_states: Optional[bool] = None,
  1506. labels: Optional[torch.Tensor] = None,
  1507. return_dict: Optional[bool] = None,
  1508. ) -> Union[tuple, PerceiverClassifierOutput]:
  1509. r"""
  1510. inputs (`torch.FloatTensor`):
  1511. Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
  1512. subsampled_output_points (`dict[str, torch.Tensor]`, *optional*):
  1513. Dictionary of tensors used as queries for the decoder. The decoder maps these queries to the latent
  1514. representation of the model. Used for subsampled decoding, e.g. when only decoding certain image patches.
  1515. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1516. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  1517. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1518. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1519. Examples:
  1520. ```python
  1521. >>> from transformers import PerceiverForMultimodalAutoencoding
  1522. >>> import torch
  1523. >>> import numpy as np
  1524. >>> # create multimodal inputs
  1525. >>> images = torch.randn((1, 16, 3, 224, 224))
  1526. >>> audio = torch.randn((1, 30720, 1))
  1527. >>> inputs = dict(image=images, audio=audio, label=torch.zeros((images.shape[0], 700)))
  1528. >>> model = PerceiverForMultimodalAutoencoding.from_pretrained("deepmind/multimodal-perceiver")
  1529. >>> # in the Perceiver IO paper, videos are auto-encoded in chunks
  1530. >>> # each chunk subsamples different index dimensions of the image and audio modality decoder queries
  1531. >>> nchunks = 128
  1532. >>> image_chunk_size = np.prod((16, 224, 224)) // nchunks
  1533. >>> audio_chunk_size = audio.shape[1] // model.config.samples_per_patch // nchunks
  1534. >>> # process the first chunk
  1535. >>> chunk_idx = 0
  1536. >>> subsampling = {
  1537. ... "image": torch.arange(image_chunk_size * chunk_idx, image_chunk_size * (chunk_idx + 1)),
  1538. ... "audio": torch.arange(audio_chunk_size * chunk_idx, audio_chunk_size * (chunk_idx + 1)),
  1539. ... "label": None,
  1540. ... }
  1541. >>> outputs = model(inputs=inputs, subsampled_output_points=subsampling)
  1542. >>> logits = outputs.logits
  1543. >>> list(logits["audio"].shape)
  1544. [1, 240]
  1545. >>> list(logits["image"].shape)
  1546. [1, 6272, 3]
  1547. >>> list(logits["label"].shape)
  1548. [1, 700]
  1549. ```"""
  1550. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1551. loss = None
  1552. if labels is not None:
  1553. raise NotImplementedError("Multimodal autoencoding training is not yet supported")
  1554. outputs = self.perceiver(
  1555. inputs=inputs,
  1556. attention_mask=attention_mask,
  1557. subsampled_output_points=subsampled_output_points,
  1558. head_mask=head_mask,
  1559. output_attentions=output_attentions,
  1560. output_hidden_states=output_hidden_states,
  1561. return_dict=return_dict,
  1562. )
  1563. logits = outputs.logits if return_dict else outputs[0]
  1564. if not return_dict:
  1565. output = (logits,) + outputs[2:]
  1566. return ((loss,) + output) if loss is not None else output
  1567. return PerceiverClassifierOutput(
  1568. loss=loss,
  1569. logits=logits,
  1570. hidden_states=outputs.hidden_states,
  1571. attentions=outputs.attentions,
  1572. cross_attentions=outputs.cross_attentions,
  1573. )
  1574. # Below: position encodings
  1575. def build_position_encoding(
  1576. position_encoding_type,
  1577. out_channels=None,
  1578. project_pos_dim=-1,
  1579. trainable_position_encoding_kwargs=None,
  1580. fourier_position_encoding_kwargs=None,
  1581. ):
  1582. """
  1583. Builds the position encoding.
  1584. Args:
  1585. - out_channels: refers to the number of channels of the position encodings.
  1586. - project_pos_dim: if specified, will project the position encodings to this dimension.
  1587. """
  1588. if position_encoding_type == "trainable":
  1589. if not trainable_position_encoding_kwargs:
  1590. raise ValueError("Make sure to pass trainable_position_encoding_kwargs")
  1591. output_pos_enc = PerceiverTrainablePositionEncoding(**trainable_position_encoding_kwargs)
  1592. elif position_encoding_type == "fourier":
  1593. # We don't use the index_dims argument, as this is only known during the forward pass
  1594. if not fourier_position_encoding_kwargs:
  1595. raise ValueError("Make sure to pass fourier_position_encoding_kwargs")
  1596. output_pos_enc = PerceiverFourierPositionEncoding(**fourier_position_encoding_kwargs)
  1597. else:
  1598. raise ValueError(f"Unknown position encoding type: {position_encoding_type}.")
  1599. # Optionally, project the position encoding to a target dimension:
  1600. positions_projection = nn.Linear(out_channels, project_pos_dim) if project_pos_dim > 0 else nn.Identity()
  1601. return output_pos_enc, positions_projection
  1602. # Below: Perceiver decoders
  1603. class PerceiverAbstractDecoder(nn.Module, metaclass=abc.ABCMeta):
  1604. """Perceiver abstract decoder."""
  1605. @abc.abstractmethod
  1606. def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
  1607. raise NotImplementedError
  1608. @property
  1609. @abc.abstractmethod
  1610. def num_query_channels(self):
  1611. raise NotImplementedError
  1612. @abc.abstractmethod
  1613. def forward(self, query, z, query_mask=None):
  1614. raise NotImplementedError
  1615. class PerceiverProjectionDecoder(PerceiverAbstractDecoder):
  1616. """
  1617. Baseline projection decoder (no cross-attention).
  1618. Args:
  1619. config ([`PerceiverConfig`]):
  1620. Model configuration.
  1621. """
  1622. def __init__(self, config):
  1623. super().__init__()
  1624. self.classifier = nn.Linear(config.d_latents, config.num_labels)
  1625. def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
  1626. return None
  1627. def forward(
  1628. self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None
  1629. ) -> torch.FloatTensor:
  1630. # (batch_size, num_latents, d_latents) -> (batch_size, d_latents)
  1631. z = torch.mean(z, dim=1)
  1632. # (batch_size, d_latents) -> (batch_size, config.num_labels)
  1633. logits = self.classifier(z)
  1634. return logits
  1635. class PerceiverBasicDecoder(PerceiverAbstractDecoder):
  1636. """
  1637. Cross-attention-based decoder. This class can be used to decode the final hidden states of the latents using a
  1638. cross-attention operation, in which the latents produce keys and values.
  1639. The shape of the output of this class depends on how one defines the output queries (also called decoder queries).
  1640. Args:
  1641. config ([*PerceiverConfig*]):
  1642. Model configuration.
  1643. output_num_channels (`int`, *optional*):
  1644. The number of channels in the output. Will only be used in case *final_project* is set to `True`.
  1645. position_encoding_type (`str`, *optional*, defaults to "trainable"):
  1646. The type of position encoding to use. Can be either "trainable", "fourier", or "none".
  1647. output_index_dims (`int`, *optional*):
  1648. The number of dimensions of the output queries. Ignored if 'position_encoding_type' == 'none'.
  1649. num_channels (`int`, *optional*, defaults to 128):
  1650. The number of channels of the decoder queries. Ignored if 'position_encoding_type' == 'none'.
  1651. qk_channels (`int`, *optional*):
  1652. The number of channels of the queries and keys in the cross-attention layer.
  1653. v_channels (`int`, *optional*):
  1654. The number of channels of the values in the cross-attention layer.
  1655. num_heads (`int`, *optional*, defaults to 1):
  1656. The number of attention heads in the cross-attention layer.
  1657. widening_factor (`int`, *optional*, defaults to 1):
  1658. The widening factor of the cross-attention layer.
  1659. use_query_residual (`bool`, *optional*, defaults to `False`):
  1660. Whether to use a residual connection between the query and the output of the cross-attention layer.
  1661. concat_preprocessed_input (`bool`, *optional*, defaults to `False`):
  1662. Whether to concatenate the preprocessed input to the query.
  1663. final_project (`bool`, *optional*, defaults to `True`):
  1664. Whether to project the output of the cross-attention layer to a target dimension.
  1665. position_encoding_only (`bool`, *optional*, defaults to `False`):
  1666. Whether to only use this class to define output queries.
  1667. """
  1668. def __init__(
  1669. self,
  1670. config: PerceiverConfig,
  1671. output_num_channels: int,
  1672. position_encoding_type: Optional[str] = "trainable",
  1673. # The following 2 arguments are ignored if position_encoding_type == 'none':
  1674. output_index_dims: Optional[int] = None,
  1675. num_channels: Optional[int] = 128,
  1676. subsampled_index_dims: Optional[int] = None,
  1677. qk_channels: Optional[int] = None,
  1678. v_channels: Optional[int] = None,
  1679. num_heads: Optional[int] = 1,
  1680. widening_factor: Optional[int] = 1,
  1681. use_query_residual: Optional[bool] = False,
  1682. concat_preprocessed_input: Optional[bool] = False,
  1683. final_project: Optional[bool] = True,
  1684. position_encoding_only: Optional[bool] = False,
  1685. **position_encoding_kwargs,
  1686. ) -> None:
  1687. super().__init__()
  1688. self.output_num_channels = output_num_channels
  1689. # If `none`, the decoder will not construct any position encodings.
  1690. # You should construct your own when querying the decoder.
  1691. self.output_position_encodings = None
  1692. self.position_encoding_type = position_encoding_type
  1693. self.position_encoding_kwargs = position_encoding_kwargs
  1694. if position_encoding_type != "none":
  1695. self.output_position_encodings, self.positions_projection = build_position_encoding(
  1696. position_encoding_type=position_encoding_type, **position_encoding_kwargs
  1697. )
  1698. self.output_index_dims = output_index_dims
  1699. self.num_channels = num_channels
  1700. if subsampled_index_dims is None:
  1701. subsampled_index_dims = output_index_dims
  1702. self.subsampled_index_dims = subsampled_index_dims
  1703. self.concat_preprocessed_input = concat_preprocessed_input
  1704. self.final_project = final_project
  1705. self.position_encoding_only = position_encoding_only
  1706. # for multimodal autoencoding, we don't need the decoder cross-attention and final layer
  1707. # so then we will set position_encoding_only to True
  1708. if not self.position_encoding_only:
  1709. self.decoding_cross_attention = PerceiverLayer(
  1710. config,
  1711. is_cross_attention=True,
  1712. qk_channels=qk_channels,
  1713. v_channels=v_channels,
  1714. num_heads=num_heads,
  1715. q_dim=num_channels,
  1716. kv_dim=config.d_latents,
  1717. widening_factor=widening_factor,
  1718. use_query_residual=use_query_residual,
  1719. )
  1720. self.final_layer = nn.Linear(num_channels, output_num_channels) if final_project else nn.Identity()
  1721. @property
  1722. def num_query_channels(self) -> int:
  1723. if self.position_encoding_type == "none": # Queries come from elsewhere
  1724. raise ValueError(
  1725. "You cannot calculate number of decoder query channels when position_encoding_type is set to none"
  1726. )
  1727. if self.position_encoding_only:
  1728. if "project_pos_dim" in self.position_encoding_kwargs:
  1729. return self.position_encoding_kwargs["project_pos_dim"]
  1730. return self.output_position_encodings.output_size()
  1731. if self.final_project:
  1732. return self.output_num_channels
  1733. return self.num_channels
  1734. def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
  1735. if self.position_encoding_type == "none": # Queries come from elsewhere
  1736. raise ValueError("You cannot construct decoder queries when position_encoding_type is set to none")
  1737. if subsampled_points is not None:
  1738. # subsampled_points are the indices if the inputs would be flattened
  1739. # however, the inputs aren't flattened, that's why we use unravel_index
  1740. # to get the indices for the unflattened array
  1741. # unravel_index returns a tuple (x_idx, y_idx, ...)
  1742. # stack to get the [n, d] tensor of coordinates
  1743. indices = [torch.from_numpy(x) for x in np.unravel_index(subsampled_points.cpu(), self.output_index_dims)]
  1744. pos = torch.stack(indices, dim=1)
  1745. batch_size = inputs.shape[0]
  1746. # Map these coordinates to [-1, 1]
  1747. pos = -1 + 2 * pos / torch.tensor(self.output_index_dims)[None, :]
  1748. pos = torch.broadcast_to(pos[None], [batch_size, pos.shape[0], pos.shape[1]])
  1749. # Construct the position encoding.
  1750. if self.position_encoding_type == "trainable":
  1751. pos_emb = self.output_position_encodings(batch_size)
  1752. elif self.position_encoding_type == "fourier":
  1753. pos_emb = self.output_position_encodings(
  1754. self.output_index_dims, batch_size=batch_size, device=inputs.device, dtype=inputs.dtype, pos=pos
  1755. )
  1756. # Optionally project them to a target dimension.
  1757. pos_emb = self.positions_projection(pos_emb)
  1758. pos_emb = torch.reshape(pos_emb, [pos_emb.shape[0], -1, pos_emb.shape[-1]])
  1759. else:
  1760. batch_size = inputs.shape[0]
  1761. index_dims = inputs.shape[2:]
  1762. # Construct the position encoding.
  1763. if self.position_encoding_type == "trainable":
  1764. pos_emb = self.output_position_encodings(batch_size)
  1765. elif self.position_encoding_type == "fourier":
  1766. pos_emb = self.output_position_encodings(
  1767. index_dims, batch_size, device=inputs.device, dtype=inputs.dtype
  1768. )
  1769. # Optionally project them to a target dimension.
  1770. pos_emb = self.positions_projection(pos_emb)
  1771. if self.concat_preprocessed_input:
  1772. if inputs_without_pos is None:
  1773. raise ValueError("Value is required for inputs_without_pos if concat_preprocessed_input is True")
  1774. pos_emb = torch.cat([inputs_without_pos, pos_emb], dim=-1)
  1775. return pos_emb
  1776. def forward(
  1777. self,
  1778. query: torch.Tensor,
  1779. z: torch.FloatTensor,
  1780. query_mask: Optional[torch.FloatTensor] = None,
  1781. output_attentions: Optional[bool] = False,
  1782. ) -> PerceiverDecoderOutput:
  1783. # Cross-attention decoding.
  1784. # key, value: B x N x K; query: B x M x K
  1785. # Attention maps -> B x N x M
  1786. # Output -> B x M x K
  1787. cross_attentions = () if output_attentions else None
  1788. layer_outputs = self.decoding_cross_attention(
  1789. query,
  1790. attention_mask=query_mask,
  1791. head_mask=None,
  1792. inputs=z,
  1793. inputs_mask=None,
  1794. output_attentions=output_attentions,
  1795. )
  1796. output = layer_outputs[0]
  1797. if output_attentions:
  1798. cross_attentions = cross_attentions + (layer_outputs[1],)
  1799. logits = self.final_layer(output)
  1800. return PerceiverDecoderOutput(logits=logits, cross_attentions=cross_attentions)
  1801. class PerceiverClassificationDecoder(PerceiverAbstractDecoder):
  1802. """
  1803. Cross-attention based classification decoder. Light-weight wrapper of [`PerceiverBasicDecoder`] for logit output.
  1804. Will turn the output of the Perceiver encoder which is of shape (batch_size, num_latents, d_latents) to a tensor of
  1805. shape (batch_size, num_labels). The queries are of shape (batch_size, 1, num_labels).
  1806. Args:
  1807. config ([`PerceiverConfig`]):
  1808. Model configuration.
  1809. """
  1810. def __init__(self, config, **decoder_kwargs):
  1811. super().__init__()
  1812. self.num_labels = config.num_labels
  1813. self.decoder = PerceiverBasicDecoder(
  1814. config,
  1815. output_num_channels=self.num_labels,
  1816. output_index_dims=1, # Predict a single logit array.
  1817. **decoder_kwargs,
  1818. )
  1819. @property
  1820. def num_query_channels(self) -> int:
  1821. return self.decoder.num_query_channels
  1822. def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
  1823. return self.decoder.decoder_query(
  1824. inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_points
  1825. )
  1826. def forward(
  1827. self,
  1828. query: torch.Tensor,
  1829. z: torch.FloatTensor,
  1830. query_mask: Optional[torch.FloatTensor] = None,
  1831. output_attentions: Optional[bool] = False,
  1832. ) -> PerceiverDecoderOutput:
  1833. decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
  1834. # B x 1 x num_classes -> B x num_classes
  1835. logits = decoder_outputs.logits[:, 0, :]
  1836. return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions)
  1837. class PerceiverOpticalFlowDecoder(PerceiverAbstractDecoder):
  1838. """Cross-attention based optical flow decoder."""
  1839. def __init__(self, config, output_image_shape, output_num_channels=2, rescale_factor=100.0, **decoder_kwargs):
  1840. super().__init__()
  1841. self.output_image_shape = output_image_shape
  1842. self.output_num_channels = output_num_channels
  1843. self.rescale_factor = rescale_factor
  1844. self.decoder = PerceiverBasicDecoder(config, output_num_channels=output_num_channels, **decoder_kwargs)
  1845. @property
  1846. def num_query_channels(self) -> int:
  1847. return self.decoder.num_query_channels
  1848. def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
  1849. if subsampled_points is not None:
  1850. raise ValueError("FlowDecoder doesn't support subsampling yet.")
  1851. return inputs
  1852. def forward(
  1853. self,
  1854. query: torch.Tensor,
  1855. z: torch.FloatTensor,
  1856. query_mask: Optional[torch.FloatTensor] = None,
  1857. output_attentions: Optional[bool] = False,
  1858. ) -> PerceiverDecoderOutput:
  1859. decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
  1860. preds = decoder_outputs.logits
  1861. # Output flow and rescale.
  1862. preds /= self.rescale_factor
  1863. preds = preds.reshape([preds.shape[0]] + list(self.output_image_shape) + [preds.shape[-1]])
  1864. return PerceiverDecoderOutput(logits=preds, cross_attentions=decoder_outputs.cross_attentions)
  1865. class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder):
  1866. """
  1867. Cross-attention based video-autoencoding decoder. Light-weight wrapper of [*PerceiverBasicDecoder*] with video
  1868. reshaping logic.
  1869. Args:
  1870. config ([*PerceiverConfig*]):
  1871. Model configuration.
  1872. output_shape (`list[int]`):
  1873. Shape of the output as (batch_size, num_frames, height, width), excluding the channel dimension.
  1874. position_encoding_type (`str`):
  1875. The type of position encoding to use. Can be either "trainable", "fourier", or "none".
  1876. """
  1877. def __init__(
  1878. self, config: PerceiverConfig, output_shape: list[int], position_encoding_type: str, **decoder_kwargs
  1879. ) -> None:
  1880. super().__init__()
  1881. if len(output_shape) != 4: # B, T, H, W
  1882. raise ValueError(f"Expected rank 4 output_shape, got {output_shape}.")
  1883. # Build the decoder components:
  1884. self.output_shape = output_shape
  1885. self.output_num_channels = decoder_kwargs["output_num_channels"]
  1886. self.decoder = PerceiverBasicDecoder(
  1887. config,
  1888. output_index_dims=self.output_shape[1:4], # T*H*W
  1889. position_encoding_type=position_encoding_type,
  1890. **decoder_kwargs,
  1891. )
  1892. @property
  1893. def num_query_channels(self) -> int:
  1894. return self.decoder.num_query_channels
  1895. def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
  1896. return self.decoder.decoder_query(
  1897. inputs,
  1898. modality_sizes=modality_sizes,
  1899. inputs_without_pos=inputs_without_pos,
  1900. subsampled_points=subsampled_points,
  1901. )
  1902. def forward(
  1903. self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None
  1904. ) -> PerceiverDecoderOutput:
  1905. decoder_outputs = self.decoder(query, z)
  1906. logits = decoder_outputs.logits
  1907. logits = torch.reshape(logits, self.output_shape + [logits.shape[-1]])
  1908. return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions)
  1909. def restructure(modality_sizes: ModalitySizeType, inputs: torch.Tensor) -> Mapping[str, torch.Tensor]:
  1910. """
  1911. Partitions a [B, N, C] tensor into tensors for each modality.
  1912. Args:
  1913. modality_sizes
  1914. dict specifying the size of the modality
  1915. inputs:
  1916. input tensor
  1917. Returns:
  1918. dict mapping name of modality to its associated tensor.
  1919. """
  1920. outputs = {}
  1921. index = 0
  1922. # Apply a predictable ordering to the modalities
  1923. for modality in sorted(modality_sizes.keys()):
  1924. size = modality_sizes[modality]
  1925. inp = inputs[:, index : index + size]
  1926. index += size
  1927. outputs[modality] = inp
  1928. return outputs
  1929. class PerceiverMultimodalDecoder(PerceiverAbstractDecoder):
  1930. """
  1931. Multimodal decoding by composing uni-modal decoders. The *modalities* argument of the constructor is a dictionary
  1932. mapping modality name to the decoder of that modality. That decoder will be used to construct queries for that
  1933. modality. Modality-specific queries are padded with trainable modality-specific parameters, after which they are
  1934. concatenated along the time dimension.
  1935. Next, there is a shared cross attention operation across all modalities.
  1936. Args:
  1937. config ([*PerceiverConfig*]):
  1938. Model configuration.
  1939. modalities (`dict[str, PerceiverAbstractDecoder]`):
  1940. Dictionary mapping modality name to the decoder of that modality.
  1941. num_outputs (`int`):
  1942. The number of outputs of the decoder.
  1943. output_num_channels (`int`):
  1944. The number of channels in the output.
  1945. min_padding_size (`int`, *optional*, defaults to 2):
  1946. The minimum padding size for all modalities. The final output will have num_channels equal to the maximum
  1947. channels across all modalities plus min_padding_size.
  1948. subsampled_index_dims (`dict[str, PerceiverAbstractDecoder]`, *optional*):
  1949. Dictionary mapping modality name to the subsampled index dimensions to use for the decoder query of that
  1950. modality.
  1951. """
  1952. def __init__(
  1953. self,
  1954. config: PerceiverConfig,
  1955. modalities: dict[str, PerceiverAbstractDecoder],
  1956. num_outputs: int,
  1957. output_num_channels: int,
  1958. min_padding_size: Optional[int] = 2,
  1959. subsampled_index_dims: Optional[dict[str, PerceiverAbstractDecoder]] = None,
  1960. **decoder_kwargs,
  1961. ) -> None:
  1962. super().__init__()
  1963. self.modalities = nn.ModuleDict(modalities)
  1964. self.subsampled_index_dims = subsampled_index_dims
  1965. self.min_padding_size = min_padding_size
  1966. self.output_num_channels = output_num_channels
  1967. self.num_outputs = num_outputs
  1968. self.decoder = PerceiverBasicDecoder(
  1969. config,
  1970. output_index_dims=(num_outputs,),
  1971. output_num_channels=output_num_channels,
  1972. position_encoding_type="none",
  1973. num_channels=self.num_query_channels,
  1974. **decoder_kwargs,
  1975. )
  1976. self.padding = nn.ParameterDict(
  1977. {
  1978. modality: nn.Parameter(torch.randn(1, self.num_query_channels - decoder.num_query_channels))
  1979. for modality, decoder in modalities.items()
  1980. }
  1981. )
  1982. @property
  1983. def num_query_channels(self) -> int:
  1984. max_channel_size = max(decoder.num_query_channels for _, decoder in self.modalities.items())
  1985. common_channel_size = max_channel_size + self.min_padding_size
  1986. return common_channel_size
  1987. def decoder_query(self, inputs, modality_sizes, inputs_without_pos=None, subsampled_points=None):
  1988. # Partition the flat inputs among the different modalities
  1989. inputs = restructure(modality_sizes, inputs)
  1990. # Obtain modality-specific decoders' queries
  1991. subsampled_points = subsampled_points or {}
  1992. decoder_queries = {}
  1993. for modality, decoder in self.modalities.items():
  1994. # Get input_without_pos for this modality if it exists.
  1995. input_without_pos = None
  1996. if inputs_without_pos is not None:
  1997. input_without_pos = inputs_without_pos.get(modality, None)
  1998. query = decoder.decoder_query(
  1999. inputs=inputs[modality],
  2000. modality_sizes=None,
  2001. inputs_without_pos=input_without_pos,
  2002. subsampled_points=subsampled_points.get(modality, None),
  2003. )
  2004. decoder_queries[modality] = query
  2005. # Pad all queries with trainable position encodings to make them have the same channels
  2006. def embed(modality, x):
  2007. x = torch.reshape(x, [x.shape[0], np.prod(x.shape[1:-1]), x.shape[-1]])
  2008. pos = self.padding[modality]
  2009. pos = torch.broadcast_to(pos, [x.shape[0], x.shape[1], self.num_query_channels - x.shape[2]])
  2010. return torch.cat([x, pos], dim=2)
  2011. # Apply a predictable ordering to the modalities
  2012. return torch.cat(
  2013. [embed(modality, decoder_queries[modality]) for modality in sorted(self.modalities.keys())], dim=1
  2014. )
  2015. def forward(
  2016. self,
  2017. query: torch.Tensor,
  2018. z: torch.FloatTensor,
  2019. query_mask: Optional[torch.FloatTensor] = None,
  2020. output_attentions: Optional[bool] = False,
  2021. ) -> torch.Tensor:
  2022. # B x 1 x num_classes -> B x num_classes
  2023. decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
  2024. return decoder_outputs
  2025. # Below: IO pre- and post-processor classes for Perceiver.
  2026. def space_to_depth(frames: torch.Tensor, temporal_block_size: int = 1, spatial_block_size: int = 1) -> torch.Tensor:
  2027. """
  2028. Space to depth transform. Rearranges blocks of spatial data, into depth.
  2029. This function assumes the channels to be first, but will place the channels last after transformation.
  2030. Based on https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/15.
  2031. """
  2032. if len(frames.shape) == 4:
  2033. batch_size, num_channels, height, width = frames.shape
  2034. # split up dimensions (height by spatial_block_size, width by spatial_block_size)
  2035. frames = frames.view(
  2036. batch_size,
  2037. num_channels,
  2038. height // spatial_block_size,
  2039. spatial_block_size,
  2040. width // spatial_block_size,
  2041. spatial_block_size,
  2042. )
  2043. # move blocks to last dimension: (batch_size, H//bs, W//bs, bs, bs, C)
  2044. frames = frames.permute(0, 2, 4, 3, 5, 1).contiguous()
  2045. # concatenate blocks along channel dimension: (batch_size, H//bs, W//bs, bs*bs*C)
  2046. frames = frames.view(
  2047. batch_size,
  2048. height // spatial_block_size,
  2049. width // spatial_block_size,
  2050. (spatial_block_size**2) * num_channels,
  2051. )
  2052. return frames
  2053. elif len(frames.shape) == 5:
  2054. batch_size, time, num_channels, height, width = frames.shape
  2055. # split up dimensions (time by temporal_block_size, height by spatial_block_size, width by spatial_block_size)
  2056. frames = frames.view(
  2057. batch_size,
  2058. time // temporal_block_size,
  2059. temporal_block_size,
  2060. num_channels,
  2061. height // spatial_block_size,
  2062. spatial_block_size,
  2063. width // spatial_block_size,
  2064. spatial_block_size,
  2065. )
  2066. # move blocks to last dimension: (batch_size, T//ts, H//bs, W//bs, ts, bs, bs, C)
  2067. frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
  2068. # concatenate blocks along channel dimension: (batch_size, T//ts, H//bs, W//bs, ts*bs*bs*C)
  2069. frames = frames.view(
  2070. batch_size,
  2071. time // temporal_block_size,
  2072. height // spatial_block_size,
  2073. width // spatial_block_size,
  2074. temporal_block_size * (spatial_block_size**2) * num_channels,
  2075. )
  2076. return frames
  2077. else:
  2078. raise ValueError(
  2079. "Frames should be of rank 4 (batch, channels, height, width)"
  2080. " or rank 5 (batch, time, channels, height, width)"
  2081. )
  2082. class Conv2dSamePadding(nn.Conv2d):
  2083. """
  2084. Conv2d layer with padding="same" support. Source:
  2085. https://gist.github.com/sumanmichael/4de9dee93f972d47c80c4ade8e149ea6
  2086. """
  2087. def __init__(self, *args, **kwargs):
  2088. super().__init__(*args, **kwargs)
  2089. self.zero_pad_2d = nn.ZeroPad2d(
  2090. reduce(__add__, [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) for k in self.kernel_size[::-1]])
  2091. )
  2092. def forward(self, input):
  2093. return self._conv_forward(self.zero_pad_2d(input), self.weight, self.bias)
  2094. class Conv2DDownsample(nn.Module):
  2095. """Downsamples 4x by applying a 2D convolution and doing max pooling."""
  2096. def __init__(
  2097. self,
  2098. num_layers: int = 1,
  2099. in_channels: int = 3,
  2100. out_channels: int = 64,
  2101. use_batchnorm: bool = True,
  2102. ):
  2103. """
  2104. Constructs a Conv2DDownsample model.
  2105. Args:
  2106. in_channels (`int`, *optional*, defaults to 3):
  2107. The number of input channels.
  2108. out_channels (`int`, *optional*, defaults to 64):
  2109. The number of conv output channels.
  2110. use_batchnorm (`bool`, *optional*, defaults to `True`):
  2111. Whether to use batchnorm.
  2112. """
  2113. super().__init__()
  2114. self.conv = Conv2dSamePadding(
  2115. in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=2, bias=False
  2116. )
  2117. self.batchnorm = nn.BatchNorm2d(num_features=out_channels) if use_batchnorm else nn.Identity()
  2118. self.relu = nn.ReLU()
  2119. self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)
  2120. def forward(self, inputs: torch.Tensor) -> torch.Tensor:
  2121. out = self.conv(inputs)
  2122. out = self.batchnorm(out)
  2123. out = self.relu(out)
  2124. out = self.max_pool(out)
  2125. return out
  2126. def generate_fourier_features(pos, num_bands, max_resolution=(224, 224), concat_pos=True, sine_only=False):
  2127. """
  2128. Generate a Fourier frequency position encoding with linear spacing.
  2129. Args:
  2130. pos (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`):
  2131. The Tensor containing the position of n points in d dimensional space.
  2132. num_bands (`int`):
  2133. The number of frequency bands (K) to use.
  2134. max_resolution (`tuple[int]`, *optional*, defaults to (224, 224)):
  2135. The maximum resolution (i.e. the number of pixels per dim). A tuple representing resolution for each dimension.
  2136. concat_pos (`bool`, *optional*, defaults to `True`):
  2137. Whether to concatenate the input position encoding to the Fourier features.
  2138. sine_only (`bool`, *optional*, defaults to `False`):
  2139. Whether to use a single phase (sin) or two (sin/cos) for each frequency band.
  2140. Returns:
  2141. `torch.FloatTensor` of shape `(batch_size, sequence_length, n_channels)`: The Fourier position embeddings. If
  2142. `concat_pos` is `True` and `sine_only` is `False`, output dimensions are ordered as: [dim_1, dim_2, ..., dim_d,
  2143. sin(pi*f_1*dim_1), ..., sin(pi*f_K*dim_1), ..., sin(pi*f_1*dim_d), ..., sin(pi*f_K*dim_d), cos(pi*f_1*dim_1),
  2144. ..., cos(pi*f_K*dim_1), ..., cos(pi*f_1*dim_d), ..., cos(pi*f_K*dim_d)], where dim_i is pos[:, i] and f_k is the
  2145. kth frequency band.
  2146. """
  2147. batch_size = pos.shape[0]
  2148. min_freq = 1.0
  2149. # Nyquist frequency at the target resolution:
  2150. freq_bands = torch.stack(
  2151. [torch.linspace(start=min_freq, end=res / 2, steps=num_bands) for res in max_resolution], dim=0
  2152. )
  2153. # Get frequency bands for each spatial dimension.
  2154. # Output is size [n, d * num_bands]
  2155. per_pos_features = pos[0, :, :][:, :, None] * freq_bands[None, :, :]
  2156. per_pos_features = torch.reshape(per_pos_features, [-1, np.prod(per_pos_features.shape[1:])])
  2157. if sine_only:
  2158. # Output is size [n, d * num_bands]
  2159. per_pos_features = torch.sin(np.pi * (per_pos_features))
  2160. else:
  2161. # Output is size [n, 2 * d * num_bands]
  2162. per_pos_features = torch.cat(
  2163. [torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)], dim=-1
  2164. )
  2165. # Concatenate the raw input positions.
  2166. if concat_pos:
  2167. # Adds d bands to the encoding.
  2168. per_pos_features = torch.cat([pos, per_pos_features.expand(batch_size, -1, -1)], dim=-1)
  2169. return per_pos_features
  2170. def build_linear_positions(index_dims, output_range=(-1.0, 1.0)):
  2171. """
  2172. Generate an array of position indices for an N-D input array.
  2173. Args:
  2174. index_dims (`list[int]`):
  2175. The shape of the index dimensions of the input array.
  2176. output_range (`tuple[float]`, *optional*, defaults to `(-1.0, 1.0)`):
  2177. The min and max values taken by each input index dimension.
  2178. Returns:
  2179. `torch.FloatTensor` of shape `(index_dims[0], index_dims[1], .., index_dims[-1], N)`.
  2180. """
  2181. def _linspace(n_xels_per_dim):
  2182. return torch.linspace(start=output_range[0], end=output_range[1], steps=n_xels_per_dim, dtype=torch.float32)
  2183. dim_ranges = [_linspace(n_xels_per_dim) for n_xels_per_dim in index_dims]
  2184. array_index_grid = meshgrid(*dim_ranges, indexing="ij")
  2185. return torch.stack(array_index_grid, dim=-1)
  2186. class PerceiverAbstractPositionEncoding(nn.Module, metaclass=abc.ABCMeta):
  2187. """Perceiver abstract position encoding."""
  2188. @property
  2189. @abc.abstractmethod
  2190. def num_dimensions(self) -> int:
  2191. raise NotImplementedError
  2192. @abc.abstractmethod
  2193. def output_size(self, *args, **kwargs) -> int:
  2194. raise NotImplementedError
  2195. @abc.abstractmethod
  2196. def forward(self, batch_size, pos):
  2197. raise NotImplementedError
  2198. class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding):
  2199. """Trainable position encoding."""
  2200. def __init__(self, index_dims, num_channels=128):
  2201. super().__init__()
  2202. self._num_channels = num_channels
  2203. self._index_dims = index_dims
  2204. index_dim = np.prod(index_dims)
  2205. self.position_embeddings = nn.Parameter(torch.randn(index_dim, num_channels))
  2206. @property
  2207. def num_dimensions(self) -> int:
  2208. if isinstance(self._index_dims, int):
  2209. return 1
  2210. return len(self._index_dims)
  2211. def output_size(self, *args, **kwargs) -> int:
  2212. return self._num_channels
  2213. def interpolate_pos_encoding(self, position_embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  2214. num_positions = position_embeddings.shape[0]
  2215. new_height = new_width = torch_int(num_positions**0.5)
  2216. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  2217. if not torch.jit.is_tracing() and height == new_height and width == new_width:
  2218. return position_embeddings
  2219. position_embeddings = position_embeddings.reshape(1, new_height, new_width, self._num_channels).permute(
  2220. 0, 3, 1, 2
  2221. )
  2222. position_embeddings = nn.functional.interpolate(
  2223. position_embeddings,
  2224. size=(new_height, new_width),
  2225. mode="bicubic",
  2226. align_corners=False,
  2227. )
  2228. position_embeddings = position_embeddings.reshape(1, self._num_channels, -1).permute(0, 2, 1).squeeze(0)
  2229. return position_embeddings
  2230. def forward(
  2231. self, batch_size: int, interpolate_pos_encoding: bool = False, input_size: Optional[torch.Size] = None
  2232. ) -> torch.Tensor:
  2233. position_embeddings = self.position_embeddings
  2234. if interpolate_pos_encoding:
  2235. height, width = input_size
  2236. position_embeddings = self.interpolate_pos_encoding(position_embeddings, height, width)
  2237. if batch_size is not None:
  2238. position_embeddings = position_embeddings.expand(batch_size, -1, -1)
  2239. return position_embeddings
  2240. def _check_or_build_spatial_positions(pos, index_dims, batch_size):
  2241. """
  2242. Checks or builds spatial position features (x, y, ...).
  2243. Args:
  2244. pos (`torch.FloatTensor`):
  2245. None, or an array of position features. If None, position features are built. Otherwise, their size is checked.
  2246. index_dims (`list[int]`):
  2247. An iterable giving the spatial/index size of the data to be featurized.
  2248. batch_size (`int`):
  2249. The batch size of the data to be featurized.
  2250. Returns:
  2251. `torch.FloatTensor` of shape `(batch_size, prod(index_dims))` an array of position features.
  2252. """
  2253. if pos is None:
  2254. pos = build_linear_positions(index_dims)
  2255. # equivalent to `torch.broadcast_to(pos[None], (batch_size,) + pos.shape)`
  2256. # but `torch.broadcast_to` cannot be converted to ONNX
  2257. pos = pos[None].expand((batch_size,) + pos.shape)
  2258. pos = torch.reshape(pos, [batch_size, np.prod(index_dims), -1])
  2259. else:
  2260. # Just a warning label: you probably don't want your spatial features to
  2261. # have a different spatial layout than your pos coordinate system.
  2262. # But feel free to override if you think it'll work!
  2263. if pos.shape[-1] != len(index_dims):
  2264. raise ValueError("Spatial features have the wrong number of dimensions.")
  2265. return pos
  2266. class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding):
  2267. """Fourier (Sinusoidal) position encoding."""
  2268. def __init__(self, num_bands, max_resolution, concat_pos=True, sine_only=False):
  2269. super().__init__()
  2270. self.num_bands = num_bands
  2271. self.max_resolution = max_resolution
  2272. self.concat_pos = concat_pos
  2273. self.sine_only = sine_only
  2274. @property
  2275. def num_dimensions(self) -> int:
  2276. return len(self.max_resolution)
  2277. def output_size(self):
  2278. """Returns size of positional encodings last dimension."""
  2279. num_dims = len(self.max_resolution)
  2280. encoding_size = self.num_bands * num_dims
  2281. if not self.sine_only:
  2282. encoding_size *= 2
  2283. if self.concat_pos:
  2284. encoding_size += self.num_dimensions
  2285. return encoding_size
  2286. def forward(
  2287. self,
  2288. index_dims: list[int],
  2289. batch_size: int,
  2290. device: torch.device,
  2291. dtype: torch.dtype,
  2292. pos: Optional[torch.FloatTensor] = None,
  2293. ) -> torch.FloatTensor:
  2294. pos = _check_or_build_spatial_positions(pos, index_dims, batch_size)
  2295. fourier_pos_enc = generate_fourier_features(
  2296. pos,
  2297. num_bands=self.num_bands,
  2298. max_resolution=self.max_resolution,
  2299. concat_pos=self.concat_pos,
  2300. sine_only=self.sine_only,
  2301. ).to(device=device, dtype=dtype)
  2302. return fourier_pos_enc
  2303. class AbstractPreprocessor(nn.Module):
  2304. @property
  2305. def num_channels(self) -> int:
  2306. """Returns size of preprocessor output."""
  2307. raise NotImplementedError()
  2308. class PerceiverTextPreprocessor(AbstractPreprocessor):
  2309. """
  2310. Text preprocessing for Perceiver Encoder. Can be used to embed `inputs` and add positional encodings.
  2311. The dimensionality of the embeddings is determined by the `d_model` attribute of the configuration.
  2312. Args:
  2313. config ([`PerceiverConfig`]):
  2314. Model configuration.
  2315. """
  2316. def __init__(self, config: PerceiverConfig) -> None:
  2317. super().__init__()
  2318. self.config = config
  2319. self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model)
  2320. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
  2321. @property
  2322. def num_channels(self) -> int:
  2323. return self.config.d_model
  2324. def forward(
  2325. self,
  2326. inputs: torch.LongTensor,
  2327. pos: Optional[torch.Tensor] = None,
  2328. network_input_is_1d: bool = True,
  2329. interpolate_pos_encoding: bool = False,
  2330. ):
  2331. embeddings_without_pos = self.embeddings(inputs)
  2332. seq_length = inputs.shape[1]
  2333. position_ids = torch.arange(0, seq_length, device=inputs.device)
  2334. embeddings = embeddings_without_pos + self.position_embeddings(position_ids)
  2335. return embeddings, None, embeddings_without_pos
  2336. class PerceiverEmbeddingDecoder(nn.Module):
  2337. """
  2338. Module to decode embeddings (for masked language modeling).
  2339. Args:
  2340. config ([`PerceiverConfig`]):
  2341. Model configuration.
  2342. """
  2343. def __init__(self, config: PerceiverConfig) -> None:
  2344. super().__init__()
  2345. self.config = config
  2346. self.vocab_size = config.vocab_size
  2347. self.bias = nn.Parameter(torch.zeros(self.vocab_size))
  2348. def forward(self, hidden_states: torch.Tensor, embedding_layer: torch.Tensor) -> torch.Tensor:
  2349. batch_size, seq_len, d_model = hidden_states.shape
  2350. # Flatten batch dim
  2351. output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.transpose(0, 1))
  2352. output = output + self.bias
  2353. return output.reshape([batch_size, seq_len, self.vocab_size])
  2354. class PerceiverMultimodalPostprocessor(nn.Module):
  2355. """
  2356. Multimodal postprocessing for Perceiver. Can be used to combine modality-specific postprocessors into a single
  2357. postprocessor.
  2358. Args:
  2359. modalities (`Mapping[str, PostprocessorType]`):
  2360. Dictionary mapping modality name to postprocessor class for that modality.
  2361. input_is_dict (`bool`, *optional*, defaults to `False`):
  2362. If True, input is assumed to be dictionary structured, and outputs keep the same dictionary shape. If
  2363. False, input is a tensor which is sliced up during postprocessing by *modality_sizes*.
  2364. """
  2365. def __init__(self, modalities: Mapping[str, PostprocessorType], input_is_dict: bool = False):
  2366. super().__init__()
  2367. self.modalities = nn.ModuleDict(modalities)
  2368. self.input_is_dict = input_is_dict
  2369. def forward(
  2370. self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None
  2371. ) -> Mapping[str, torch.Tensor]:
  2372. if not self.input_is_dict:
  2373. # Slice up modalities by their sizes.
  2374. if modality_sizes is None:
  2375. raise ValueError("Modality sizes should be specified if input is not a dictionary.")
  2376. inputs = restructure(modality_sizes=modality_sizes, inputs=inputs)
  2377. outputs = {
  2378. modality: postprocessor(inputs[modality], pos=pos, modality_sizes=None)
  2379. for modality, postprocessor in self.modalities.items()
  2380. }
  2381. return outputs
  2382. class PerceiverClassificationPostprocessor(nn.Module):
  2383. """
  2384. Classification postprocessing for Perceiver. Can be used to convert the decoder output to classification logits.
  2385. Args:
  2386. config ([*PerceiverConfig*]):
  2387. Model configuration.
  2388. in_channels (`int`):
  2389. Number of channels in the input.
  2390. """
  2391. def __init__(self, config: PerceiverConfig, in_channels: int) -> None:
  2392. super().__init__()
  2393. self.classifier = nn.Linear(in_channels, config.num_labels)
  2394. def forward(self, inputs, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor:
  2395. logits = self.classifier(inputs)
  2396. return logits[:, 0, :]
  2397. class PerceiverAudioPostprocessor(nn.Module):
  2398. """
  2399. Audio postprocessing for Perceiver. Can be used to convert the decoder output to audio features.
  2400. Args:
  2401. config ([*PerceiverConfig*]):
  2402. Model configuration.
  2403. in_channels (`int`):
  2404. Number of channels in the input.
  2405. postproc_type (`str`, *optional*, defaults to `"patches"`):
  2406. Postprocessor type to use. Currently, only "patches" is supported.
  2407. """
  2408. def __init__(self, config: PerceiverConfig, in_channels: int, postproc_type: str = "patches") -> None:
  2409. super().__init__()
  2410. if postproc_type != "patches": # to be supported: 'conv', 'patches', 'pixels'
  2411. raise ValueError("Invalid postproc_type!")
  2412. # Architecture parameters:
  2413. self.classifier = nn.Linear(in_channels, config.samples_per_patch)
  2414. def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor:
  2415. logits = self.classifier(inputs)
  2416. return torch.reshape(logits, [inputs.shape[0], -1])
  2417. class PerceiverProjectionPostprocessor(nn.Module):
  2418. """
  2419. Projection postprocessing for Perceiver. Can be used to project the channels of the decoder output to a lower
  2420. dimension.
  2421. Args:
  2422. in_channels (`int`):
  2423. Number of channels in the input.
  2424. out_channels (`int`):
  2425. Number of channels in the output.
  2426. """
  2427. def __init__(self, in_channels: int, out_channels: int) -> None:
  2428. super().__init__()
  2429. self.classifier = nn.Linear(in_channels, out_channels)
  2430. def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor:
  2431. logits = self.classifier(inputs)
  2432. return logits
  2433. class PerceiverImagePreprocessor(AbstractPreprocessor):
  2434. """
  2435. Image preprocessing for Perceiver Encoder.
  2436. Note: the *out_channels* argument refers to the output channels of a convolutional layer, if *prep_type* is set to
  2437. "conv1x1" or "conv". If one adds absolute position embeddings, one must make sure the *num_channels* of the
  2438. position encoding kwargs are set equal to the *out_channels*.
  2439. Args:
  2440. config ([*PerceiverConfig*]):
  2441. Model configuration.
  2442. prep_type (`str`, *optional*, defaults to `"conv"`):
  2443. Preprocessing type. Can be "conv1x1", "conv", "patches", "pixels".
  2444. spatial_downsample (`int`, *optional*, defaults to 4):
  2445. Spatial downsampling factor.
  2446. temporal_downsample (`int`, *optional*, defaults to 1):
  2447. Temporal downsampling factor (only relevant in case a time dimension is present).
  2448. position_encoding_type (`str`, *optional*, defaults to `"fourier"`):
  2449. Position encoding type. Can be "fourier" or "trainable".
  2450. in_channels (`int`, *optional*, defaults to 3):
  2451. Number of channels in the input.
  2452. out_channels (`int`, *optional*, defaults to 64):
  2453. Number of channels in the output.
  2454. conv_after_patching (`bool`, *optional*, defaults to `False`):
  2455. Whether to apply a convolutional layer after patching.
  2456. conv_after_patching_in_channels (`int`, *optional*, defaults to 54):
  2457. Number of channels in the input of the convolutional layer after patching.
  2458. conv2d_use_batchnorm (`bool`, *optional*, defaults to `True`):
  2459. Whether to use batch normalization in the convolutional layer.
  2460. concat_or_add_pos (`str`, *optional*, defaults to `"concat"`):
  2461. How to concatenate the position encoding to the input. Can be "concat" or "add".
  2462. project_pos_dim (`int`, *optional*, defaults to -1):
  2463. Dimension of the position encoding to project to. If -1, no projection is applied.
  2464. **position_encoding_kwargs (`Dict`, *optional*):
  2465. Keyword arguments for the position encoding.
  2466. """
  2467. def __init__(
  2468. self,
  2469. config,
  2470. prep_type="conv",
  2471. spatial_downsample: int = 4,
  2472. temporal_downsample: int = 1,
  2473. position_encoding_type: str = "fourier",
  2474. in_channels: int = 3,
  2475. out_channels: int = 64,
  2476. conv_after_patching: bool = False,
  2477. conv_after_patching_in_channels: int = 54, # only relevant when conv_after_patching = True
  2478. conv2d_use_batchnorm: bool = True,
  2479. concat_or_add_pos: str = "concat",
  2480. project_pos_dim: int = -1,
  2481. **position_encoding_kwargs,
  2482. ):
  2483. super().__init__()
  2484. self.config = config
  2485. if prep_type not in ("conv", "patches", "pixels", "conv1x1"):
  2486. raise ValueError(f"Prep_type {prep_type} is invalid")
  2487. if concat_or_add_pos not in ["concat", "add"]:
  2488. raise ValueError(f"Invalid value {concat_or_add_pos} for concat_or_add_pos.")
  2489. self.in_channels = in_channels
  2490. self.prep_type = prep_type
  2491. self.spatial_downsample = spatial_downsample
  2492. self.temporal_downsample = temporal_downsample
  2493. self.position_encoding_type = position_encoding_type
  2494. self.concat_or_add_pos = concat_or_add_pos
  2495. self.conv_after_patching = conv_after_patching
  2496. self.out_channels = out_channels
  2497. if self.prep_type == "conv":
  2498. # Downsampling with conv is currently restricted
  2499. convnet_num_layers = math.log(spatial_downsample, 4)
  2500. convnet_num_layers_is_int = convnet_num_layers == np.round(convnet_num_layers)
  2501. if not convnet_num_layers_is_int or temporal_downsample != 1:
  2502. raise ValueError(
  2503. "Only powers of 4 expected for spatial and 1 expected for temporal downsampling with conv."
  2504. )
  2505. self.convnet = Conv2DDownsample(
  2506. in_channels=in_channels,
  2507. num_layers=int(convnet_num_layers),
  2508. out_channels=out_channels,
  2509. use_batchnorm=conv2d_use_batchnorm,
  2510. )
  2511. elif self.prep_type == "conv1x1":
  2512. if temporal_downsample != 1:
  2513. raise ValueError("Conv1x1 does not downsample in time.")
  2514. self.convnet_1x1 = nn.Conv2d(
  2515. in_channels=in_channels,
  2516. out_channels=out_channels,
  2517. kernel_size=(1, 1),
  2518. # spatial_downsample is unconstrained for 1x1 convolutions.
  2519. stride=(spatial_downsample, spatial_downsample),
  2520. )
  2521. # Position embeddings
  2522. self.project_pos_dim = project_pos_dim
  2523. self.position_embeddings, self.positions_projection = build_position_encoding(
  2524. position_encoding_type=position_encoding_type,
  2525. out_channels=out_channels,
  2526. project_pos_dim=project_pos_dim,
  2527. **position_encoding_kwargs,
  2528. )
  2529. # Optional convolutional layer after patches.
  2530. self.conv_after_patches = (
  2531. nn.Linear(conv_after_patching_in_channels, self.out_channels) if conv_after_patching else nn.Identity()
  2532. )
  2533. @property
  2534. def num_channels(self) -> int:
  2535. # Let's assume that the number of resolutions (in the context of image preprocessing)
  2536. # of the input data is 2 or 3 depending on whether we are processing image or video respectively.
  2537. # In this case, for convenience, we will declare is_temporal variable,
  2538. # which will show whether the data has a temporal dimension or not.
  2539. is_temporal = self.position_embeddings.num_dimensions > 2
  2540. # position embedding
  2541. if self.project_pos_dim > 0:
  2542. pos_dim = self.project_pos_dim
  2543. else:
  2544. pos_dim = self.position_embeddings.output_size()
  2545. if self.concat_or_add_pos == "add":
  2546. return pos_dim
  2547. # inputs
  2548. if self.conv_after_patching or self.prep_type in ("conv1x1", "conv"):
  2549. inp_dim = self.out_channels
  2550. elif self.prep_type == "pixels":
  2551. inp_dim = self.in_channels
  2552. if not is_temporal:
  2553. inp_dim = math.ceil(inp_dim / self.spatial_downsample)
  2554. elif self.prep_type == "patches":
  2555. if self.conv_after_patching:
  2556. inp_dim = self.out_channels
  2557. else:
  2558. inp_dim = self.in_channels * self.spatial_downsample**2
  2559. if is_temporal:
  2560. inp_dim *= self.temporal_downsample
  2561. return inp_dim + pos_dim
  2562. def _build_network_inputs(
  2563. self, inputs: torch.Tensor, network_input_is_1d: bool = True, interpolate_pos_encoding: bool = False
  2564. ):
  2565. """
  2566. Construct the final input, including position encoding.
  2567. This method expects the inputs to always have channels as last dimension.
  2568. """
  2569. batch_size = inputs.shape[0]
  2570. input_size = inputs.shape[1:3]
  2571. index_dims = inputs.shape[1:-1]
  2572. indices = np.prod(index_dims)
  2573. # Flatten input features to a 1D index dimension if necessary.
  2574. if len(inputs.shape) > 3 and network_input_is_1d:
  2575. inputs = torch.reshape(inputs, [batch_size, indices, -1])
  2576. # Construct the position encoding.
  2577. if self.position_encoding_type == "trainable":
  2578. pos_enc = self.position_embeddings(batch_size, interpolate_pos_encoding, input_size)
  2579. elif self.position_encoding_type == "fourier":
  2580. pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)
  2581. # Optionally project them to a target dimension.
  2582. pos_enc = self.positions_projection(pos_enc)
  2583. if not network_input_is_1d:
  2584. # Reshape pos to match the input feature shape
  2585. # if the network takes non-1D inputs
  2586. sh = inputs.shape
  2587. pos_enc = torch.reshape(pos_enc, list(sh)[:-1] + [-1])
  2588. if self.concat_or_add_pos == "concat":
  2589. inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1)
  2590. elif self.concat_or_add_pos == "add":
  2591. inputs_with_pos = inputs + pos_enc
  2592. return inputs_with_pos, inputs
  2593. def forward(
  2594. self,
  2595. inputs: torch.Tensor,
  2596. pos: Optional[torch.Tensor] = None,
  2597. network_input_is_1d: bool = True,
  2598. interpolate_pos_encoding: bool = False,
  2599. ):
  2600. if self.prep_type == "conv":
  2601. # Convnet image featurization.
  2602. # Downsamples spatially by a factor of 4
  2603. inputs = self.convnet(inputs)
  2604. elif self.prep_type == "conv1x1":
  2605. # map inputs to self.out_channels
  2606. inputs = self.convnet_1x1(inputs)
  2607. elif self.prep_type == "pixels":
  2608. # if requested, downsamples in the crudest way
  2609. if inputs.ndim == 4:
  2610. inputs = inputs[:: self.spatial_downsample, :: self.spatial_downsample]
  2611. elif inputs.ndim == 5:
  2612. inputs = inputs[
  2613. :, :: self.temporal_downsample, :, :: self.spatial_downsample, :: self.spatial_downsample
  2614. ]
  2615. else:
  2616. raise ValueError("Unsupported data format for pixels.")
  2617. elif self.prep_type == "patches":
  2618. # Space2depth featurization.
  2619. # Video: B x T x C x H x W
  2620. inputs = space_to_depth(
  2621. inputs, temporal_block_size=self.temporal_downsample, spatial_block_size=self.spatial_downsample
  2622. )
  2623. if inputs.ndim == 5 and inputs.shape[1] == 1:
  2624. # for flow
  2625. inputs = inputs.squeeze(dim=1)
  2626. # Optionally apply conv layer.
  2627. inputs = self.conv_after_patches(inputs)
  2628. if self.prep_type != "patches":
  2629. # move channels to last dimension, as the _build_network_inputs method below expects this
  2630. if inputs.ndim == 4:
  2631. inputs = inputs.permute(0, 2, 3, 1)
  2632. elif inputs.ndim == 5:
  2633. inputs = inputs.permute(0, 1, 3, 4, 2)
  2634. else:
  2635. raise ValueError("Unsupported data format for conv1x1.")
  2636. inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d, interpolate_pos_encoding)
  2637. modality_sizes = None # Size for each modality, only needed for multimodal
  2638. return inputs, modality_sizes, inputs_without_pos
  2639. class PerceiverOneHotPreprocessor(AbstractPreprocessor):
  2640. """
  2641. One-hot preprocessor for Perceiver Encoder. Can be used to add a dummy index dimension to the input.
  2642. Args:
  2643. config ([`PerceiverConfig`]):
  2644. Model configuration.
  2645. """
  2646. def __init__(self, config: PerceiverConfig) -> None:
  2647. super().__init__()
  2648. self.config: PerceiverConfig = config
  2649. @property
  2650. def num_channels(self) -> int:
  2651. return self.config.num_labels
  2652. def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
  2653. # Add a dummy index dimension.
  2654. inputs = inputs[:, None, :]
  2655. # No position encodings, so the 1st (input) and 3rd (inputs_without_pos)
  2656. # outputs are identical.
  2657. return inputs, None, inputs
  2658. class PerceiverAudioPreprocessor(AbstractPreprocessor):
  2659. """
  2660. Audio preprocessing for Perceiver Encoder.
  2661. Args:
  2662. config ([*PerceiverConfig*]):
  2663. Model configuration.
  2664. prep_type (`str`, *optional*, defaults to `"patches"`):
  2665. Preprocessor type to use. Only "patches" is supported.
  2666. samples_per_patch (`int`, *optional*, defaults to 96):
  2667. Number of samples per patch.
  2668. position_encoding_type (`str`, *optional*, defaults to `"fourier"`):
  2669. Type of position encoding to use. Can be "trainable" or "fourier".
  2670. concat_or_add_pos (`str`, *optional*, defaults to `"concat"`):
  2671. How to concatenate the position encoding to the input. Can be "concat" or "add".
  2672. out_channels (`int`, *optional*, defaults to 64):
  2673. Number of channels in the output.
  2674. project_pos_dim (`int`, *optional*, defaults to -1):
  2675. Dimension of the position encoding to project to. If -1, no projection is applied.
  2676. **position_encoding_kwargs (`Dict`, *optional*):
  2677. Keyword arguments for the position encoding.
  2678. """
  2679. def __init__(
  2680. self,
  2681. config,
  2682. prep_type: str = "patches",
  2683. samples_per_patch: int = 96,
  2684. position_encoding_type: str = "fourier",
  2685. concat_or_add_pos: str = "concat",
  2686. out_channels=64,
  2687. project_pos_dim=-1,
  2688. **position_encoding_kwargs,
  2689. ):
  2690. super().__init__()
  2691. self.config = config
  2692. if prep_type != "patches":
  2693. raise ValueError(f"Prep_type {prep_type} is invalid, can only be 'patches'.")
  2694. if concat_or_add_pos not in ["concat", "add"]:
  2695. raise ValueError(f"Concat_or_pos {concat_or_add_pos} is invalid, can only be 'concat' or 'add'.")
  2696. self.samples_per_patch = samples_per_patch
  2697. self.position_encoding_type = position_encoding_type
  2698. self.concat_or_add_pos = concat_or_add_pos
  2699. self.project_pos_dim = project_pos_dim
  2700. # Position embeddings
  2701. self.position_embeddings, self.positions_projection = build_position_encoding(
  2702. position_encoding_type=position_encoding_type,
  2703. out_channels=out_channels,
  2704. project_pos_dim=project_pos_dim,
  2705. **position_encoding_kwargs,
  2706. )
  2707. @property
  2708. def num_channels(self) -> int:
  2709. # position embedding
  2710. if self.project_pos_dim > 0:
  2711. pos_dim = self.project_pos_dim
  2712. else:
  2713. pos_dim = self.position_embeddings.output_size()
  2714. if self.concat_or_add_pos == "add":
  2715. return pos_dim
  2716. return self.samples_per_patch + pos_dim
  2717. def _build_network_inputs(self, inputs):
  2718. """Construct the final input, including position encoding."""
  2719. batch_size = inputs.shape[0]
  2720. index_dims = inputs.shape[1:-1]
  2721. # Construct the position encoding.
  2722. if self.position_encoding_type == "trainable":
  2723. pos_enc = self.position_embeddings(batch_size)
  2724. elif self.position_encoding_type == "fourier":
  2725. pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)
  2726. # Optionally project them to a target dimension.
  2727. pos_enc = self.positions_projection(pos_enc)
  2728. if self.concat_or_add_pos == "concat":
  2729. inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1)
  2730. elif self.concat_or_add_pos == "add":
  2731. inputs_with_pos = inputs + pos_enc
  2732. return inputs_with_pos, inputs
  2733. def forward(
  2734. self,
  2735. inputs: torch.Tensor,
  2736. pos: Optional[torch.Tensor] = None,
  2737. network_input_is_1d: bool = True,
  2738. interpolate_pos_encoding: bool = False,
  2739. ):
  2740. inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch])
  2741. inputs, inputs_without_pos = self._build_network_inputs(inputs)
  2742. modality_sizes = None # Size for each modality, only needed for multimodal
  2743. return inputs, modality_sizes, inputs_without_pos
  2744. class PerceiverMultimodalPreprocessor(AbstractPreprocessor):
  2745. """
  2746. Multimodal preprocessing for Perceiver Encoder.
  2747. Inputs for each modality are preprocessed, then padded with trainable position embeddings to have the same number
  2748. of channels.
  2749. Args:
  2750. modalities (`Mapping[str, PreprocessorType]`):
  2751. Dict mapping modality name to preprocessor.
  2752. mask_probs (`dict[str, float]`):
  2753. Dict mapping modality name to masking probability of that modality.
  2754. min_padding_size (`int`, *optional*, defaults to 2):
  2755. The minimum padding size for all modalities. The final output will have num_channels equal to the maximum
  2756. channels across all modalities plus min_padding_size.
  2757. """
  2758. def __init__(
  2759. self,
  2760. modalities: Mapping[str, PreprocessorType],
  2761. mask_probs: Optional[Mapping[str, float]] = None,
  2762. min_padding_size: int = 2,
  2763. ):
  2764. super().__init__()
  2765. self.modalities = nn.ModuleDict(modalities)
  2766. self.min_padding_size = min_padding_size
  2767. self.mask_probs = mask_probs if mask_probs is not None else {}
  2768. self.padding = nn.ParameterDict(
  2769. {
  2770. modality: nn.Parameter(torch.randn(1, self.num_channels - preprocessor.num_channels))
  2771. for modality, preprocessor in modalities.items()
  2772. }
  2773. )
  2774. self.mask = nn.ParameterDict(
  2775. {modality: nn.Parameter(torch.randn(1, self.num_channels)) for modality, _ in self.mask_probs.items()}
  2776. )
  2777. @property
  2778. def num_channels(self) -> int:
  2779. max_channel_size = max(processor.num_channels for _, processor in self.modalities.items())
  2780. common_channel_size = max_channel_size + self.min_padding_size
  2781. return common_channel_size
  2782. def forward(
  2783. self,
  2784. inputs: Mapping[str, torch.Tensor],
  2785. pos: Optional[torch.Tensor] = None,
  2786. network_input_is_1d: bool = True,
  2787. interpolate_pos_encoding: bool = False,
  2788. ) -> PreprocessorOutputType:
  2789. padded = {}
  2790. modality_sizes = {}
  2791. inputs_without_pos = {}
  2792. for modality, preprocessor in self.modalities.items():
  2793. # preprocess each modality using the respective preprocessor.
  2794. output, _, inputs_without_pos[modality] = preprocessor(
  2795. inputs[modality], pos=pos, network_input_is_1d=network_input_is_1d
  2796. )
  2797. # pad to the same common_channel_size.
  2798. batch_size, num_samples, num_channels = output.shape
  2799. pos_enc = self.padding[modality].expand(batch_size, -1, -1)
  2800. padding = torch.broadcast_to(
  2801. pos_enc,
  2802. [batch_size, num_samples, self.num_channels - num_channels],
  2803. )
  2804. output_padded = torch.cat([output, padding], dim=2)
  2805. # mask if required
  2806. if modality in self.mask_probs:
  2807. mask_token = self.mask[modality].expand(batch_size, -1, -1)
  2808. mask_prob = self.mask_probs[modality]
  2809. mask = torch.bernoulli(torch.full([batch_size, num_samples], mask_prob))
  2810. mask = torch.unsqueeze(mask, dim=2).to(mask_token.device)
  2811. output_padded = (1 - mask) * output_padded + mask * mask_token
  2812. padded[modality] = output_padded
  2813. modality_sizes[modality] = output_padded.shape[1]
  2814. # Apply a predictable ordering to the modalities
  2815. padded_ls = [padded[k] for k in sorted(padded.keys())]
  2816. # Finally, concatenate along the time dimension
  2817. final_inputs = torch.cat(padded_ls, dim=1)
  2818. return final_inputs, modality_sizes, inputs_without_pos
  2819. __all__ = [
  2820. "PerceiverForImageClassificationConvProcessing",
  2821. "PerceiverForImageClassificationFourier",
  2822. "PerceiverForImageClassificationLearned",
  2823. "PerceiverForMaskedLM",
  2824. "PerceiverForMultimodalAutoencoding",
  2825. "PerceiverForOpticalFlow",
  2826. "PerceiverForSequenceClassification",
  2827. "PerceiverLayer",
  2828. "PerceiverModel",
  2829. "PerceiverPreTrainedModel",
  2830. ]