| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
27742784279428042814282428342844285428642874288428942904291429242934294429542964297429842994300430143024303430443054306430743084309431043114312431343144315431643174318431943204321432243234324432543264327432843294330433143324333433443354336433743384339434043414342434343444345434643474348434943504351435243534354435543564357435843594360436143624363436443654366436743684369437043714372437343744375437643774378437943804381438243834384438543864387438843894390439143924393439443954396439743984399440044014402440344044405440644074408440944104411441244134414441544164417441844194420442144224423442444254426442744284429443044314432443344344435443644374438443944404441444244434444444544464447444844494450445144524453445444554456445744584459446044614462446344644465446644674468446944704471447244734474447544764477447844794480448144824483448444854486448744884489449044914492449344944495449644974498449945004501450245034504450545064507450845094510451145124513451445154516451745184519452045214522452345244525452645274528452945304531453245334534453545364537453845394540454145424543454445454546454745484549455045514552455345544555455645574558455945604561456245634564456545664567456845694570457145724573457445754576457745784579458045814582458345844585458645874588458945904591459245934594459545964597459845994600460146024603460446054606460746084609461046114612461346144615461646174618461946204621462246234624462546264627462846294630463146324633463446354636463746384639464046414642464346444645464646474648464946504651465246534654465546564657465846594660466146624663466446654666466746684669467046714672467346744675467646774678467946804681468246834684468546864687468846894690469146924693469446954696469746984699470047014702470347044705470647074708470947104711471247134714471547164717471847194720472147224723472447254726472747284729473047314732473347344735473647374738473947404741474247434744474547464747474847494750475147524753475447554756475747584759476047614762476347644765476647674768476947704771477247734774477547764
77747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863486448654866486748684869487048714872487348744875487648774878487948804881488248834884488548864887488848894890489148924893489448954896489748984899490049014902490349044905490649074908490949104911491249134914491549164917491849194920492149224923492449254926492749284929493049314932493349344935493649374938493949404941494249434944494549464947494849494950495149524953495449554956495749584959496049614962496349644965496649674968496949704971497249734974497549764977497849794980498149824983498449854986498749884989499049914992499349944995499649974998499950005001500250035004500550065007500850095010501150125013501450155016501750185019502050215022502350245025502650275028502950305031503250335034503550365037503850395040504150425043504450455046504750485049505050515052505350545055505650575058505950605061506250635064506550665067506850695070507150725073507450755076507750785079508050815082508350845085508650875088508950905091509250935094509550965097509850995100510151025103510451055106510751085109511051115112511351145115511651175118511951205121512251235124512551265127512851295130513151325133513451355136513751385139514051415142514351445145514651475148514951505151515251535154515551565157515851595160516151625163516451655166516751685169517051715172517351745175517651775178517951805181518251835184518551865187518851895190519151925193519451955196519751985199520052015202520352045205520652075208520952105211521252135214521552165217521852195220522152225223522452255226522752285229523052315232523352345235523652375238523952405241524252435244524552465247524852495250525152525253525452555256525752585259526052615262526352645265526652675268526952705271527252735274527552765
27752785279528052815282528352845285528652875288528952905291529252935294529552965297529852995300530153025303530453055306530753085309531053115312531353145315531653175318531953205321532253235324532553265327532853295330533153325333533453355336533753385339534053415342534353445345534653475348534953505351535253535354535553565357535853595360536153625363536453655366536753685369537053715372537353745375537653775378537953805381538253835384538553865387538853895390539153925393539453955396539753985399540054015402540354045405540654075408540954105411541254135414541554165417541854195420542154225423542454255426542754285429543054315432543354345435543654375438543954405441544254435444544554465447544854495450545154525453545454555456545754585459546054615462546354645465546654675468546954705471547254735474547554765477547854795480548154825483548454855486548754885489549054915492549354945495549654975498549955005501550255035504550555065507550855095510551155125513551455155516551755185519552055215522552355245525552655275528552955305531553255335534553555365537553855395540554155425543554455455546554755485549555055515552555355545555555655575558555955605561556255635564556555665567556855695570557155725573557455755576557755785579558055815582558355845585558655875588558955905591559255935594559555965597559855995600560156025603560456055606560756085609561056115612561356145615561656175618561956205621562256235624562556265627562856295630563156325633563456355636563756385639564056415642564356445645564656475648564956505651565256535654565556565657565856595660566156625663566456655666566756685669567056715672567356745675567656775678567956805681568256835684568556865687568856895690569156925693569456955696569756985699570057015702570357045705570657075708570957105711571257135714571557165717571857195720572157225723572457255726572757285729573057315732573357345735573657375738573957405741574257435744574557465747574857495750575157525753575457555756575757585759576057615762576357645765576657675768576957705771577257735774577557765
77757785779578057815782578357845785578657875788578957905791579257935794579557965797579857995800580158025803580458055806580758085809581058115812581358145815581658175818581958205821582258235824582558265827582858295830583158325833583458355836583758385839584058415842584358445845584658475848584958505851585258535854585558565857585858595860586158625863586458655866586758685869587058715872587358745875587658775878587958805881588258835884588558865887588858895890589158925893589458955896589758985899590059015902590359045905590659075908590959105911591259135914591559165917591859195920592159225923592459255926592759285929593059315932593359345935593659375938593959405941594259435944594559465947594859495950595159525953595459555956595759585959596059615962596359645965596659675968596959705971597259735974597559765977597859795980598159825983598459855986598759885989599059915992599359945995599659975998599960006001600260036004600560066007600860096010601160126013601460156016601760186019602060216022602360246025602660276028602960306031603260336034603560366037603860396040604160426043604460456046604760486049605060516052605360546055605660576058605960606061606260636064606560666067606860696070607160726073607460756076607760786079608060816082608360846085608660876088608960906091609260936094609560966097609860996100610161026103610461056106610761086109611061116112611361146115611661176118611961206121612261236124612561266127612861296130613161326133613461356136613761386139614061416142614361446145614661476148614961506151615261536154615561566157615861596160616161626163616461656166616761686169617061716172617361746175617661776178617961806181618261836184618561866187618861896190619161926193619461956196619761986199620062016202620362046205620662076208620962106211621262136214621562166217621862196220622162226223622462256226622762286229623062316232623362346235623662376238623962406241624262436244624562466247624862496250625162526253625462556256625762586259626062616262626362646265626662676268626962706271627262736274627562766
277627862796280628162826283628462856286628762886289629062916292629362946295629662976298629963006301630263036304630563066307630863096310631163126313631463156316631763186319632063216322632363246325632663276328632963306331633263336334633563366337633863396340634163426343634463456346634763486349635063516352635363546355635663576358635963606361636263636364636563666367636863696370637163726373637463756376637763786379638063816382638363846385638663876388638963906391639263936394639563966397639863996400640164026403640464056406640764086409641064116412641364146415641664176418641964206421642264236424642564266427642864296430643164326433643464356436643764386439644064416442644364446445644664476448644964506451645264536454645564566457645864596460646164626463646464656466646764686469647064716472647364746475647664776478647964806481648264836484648564866487648864896490649164926493649464956496649764986499 |
- """Functional interface."""
- import importlib
- import math
- import warnings
- from typing import Callable, Optional, TYPE_CHECKING, Union
- import torch
- from torch import _VF, sym_int as _sym_int, Tensor
- from torch._C import _add_docstr, _infer_size
- from torch._jit_internal import (
- _overload,
- boolean_dispatch,
- BroadcastingList1,
- BroadcastingList2,
- BroadcastingList3,
- )
- from torch._torch_docs import reproducibility_notes, sparse_support_notes, tf32_notes
- from torch.nn import _reduction as _Reduction, grad # noqa: F401
- from torch.nn.modules.utils import _list_with_default, _pair, _single, _triple
- from torch.overrides import (
- handle_torch_function,
- has_torch_function,
- has_torch_function_unary,
- has_torch_function_variadic,
- )
- if TYPE_CHECKING:
- from torch.types import _dtype as DType
- else:
- # The JIT doesn't understand Union, nor torch.dtype here
- DType = int
- try:
- import numpy as np
- except ModuleNotFoundError:
- np = None
# Attach the user-facing docstring to the native ``torch.conv1d`` builtin and
# expose the result under this module's ``conv1d`` name.
#
# The docstring is assembled from two raw strings on purpose: only the first
# part goes through ``str.format`` (to splice in the shared TF32 and cuDNN
# reproducibility notes). The second part contains literal braces
# (``{'valid', 'same'}`` and LaTeX ``\frac{...}{...}``) and must therefore stay
# outside the ``format`` call.
conv1d = _add_docstr(
    torch.conv1d,
    r"""
conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 1D convolution over an input signal composed of several input
planes.

{tf32_note}

See :class:`~torch.nn.Conv1d` for details and output shape.

Note:
    {cudnn_reproducibility_note}

Note:
    This operator supports complex data types i.e. ``complex32, complex64, complex128``.
""".format(**reproducibility_notes, **tf32_notes)
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or
      a one-element tuple `(sW,)`. Default: 1
    padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'},
      single number or a one-element tuple `(padW,)`. Default: 0
      ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
      the input so the output has the same shape as the input. However, this mode
      doesn't support any stride values other than 1.

      .. warning::
          For ``padding='same'``, if the ``weight`` is even-length and
          ``dilation`` is odd in any dimension, a full :func:`pad` operation
          may be needed internally, lowering performance.
    dilation: the spacing between kernel elements. Can be a single number or
      a one-element tuple `(dW,)`. Default: 1
    groups: split input into groups, both :math:`\text{in\_channels}` and :math:`\text{out\_channels}`
      should be divisible by the number of groups. Default: 1

Examples::

    >>> inputs = torch.randn(33, 16, 30)
    >>> filters = torch.randn(20, 16, 5)
    >>> F.conv1d(inputs, filters)
""",
)
# Attach the user-facing docstring to the native ``torch.conv2d`` builtin and
# expose the result under this module's ``conv2d`` name.
#
# Only the first raw string is passed through ``str.format`` (it carries the
# ``{tf32_note}`` / ``{cudnn_reproducibility_note}`` placeholders). The second
# one holds literal braces (``{'valid', 'same'}``, LaTeX ``\frac{...}{...}``)
# and is concatenated unformatted.
conv2d = _add_docstr(
    torch.conv2d,
    r"""
conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 2D convolution over an input image composed of several input
planes.

{tf32_note}

See :class:`~torch.nn.Conv2d` for details and output shape.

Note:
    {cudnn_reproducibility_note}

Note:
    This operator supports complex data types i.e. ``complex32, complex64, complex128``.
""".format(**reproducibility_notes, **tf32_notes)
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
    bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple `(sH, sW)`. Default: 1
    padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'},
      single number or a tuple `(padH, padW)`. Default: 0
      ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
      the input so the output has the same shape as the input. However, this mode
      doesn't support any stride values other than 1.

      .. warning::
          For ``padding='same'``, if the ``weight`` is even-length and
          ``dilation`` is odd in any dimension, a full :func:`pad` operation
          may be needed internally, lowering performance.
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple `(dH, dW)`. Default: 1
    groups: split input into groups, both :math:`\text{in\_channels}` and :math:`\text{out\_channels}`
      should be divisible by the number of groups. Default: 1

Examples::

    >>> # With square kernels and equal stride
    >>> filters = torch.randn(8, 4, 3, 3)
    >>> inputs = torch.randn(1, 4, 5, 5)
    >>> F.conv2d(inputs, filters, padding=1)
""",
)  # noqa: E501
# Attach the user-facing docstring to the native ``torch.conv3d`` builtin and
# expose the result under this module's ``conv3d`` name.
#
# Only the first raw string is passed through ``str.format`` (it carries the
# ``{tf32_note}`` / ``{cudnn_reproducibility_note}`` placeholders). The second
# one holds literal braces (``{'valid', 'same'}``, LaTeX ``\frac{...}{...}``)
# and is concatenated unformatted.
conv3d = _add_docstr(
    torch.conv3d,
    r"""
conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor

Applies a 3D convolution over an input image composed of several input
planes.

{tf32_note}

See :class:`~torch.nn.Conv3d` for details and output shape.

Note:
    {cudnn_reproducibility_note}

Note:
    This operator supports complex data types i.e. ``complex32, complex64, complex128``.
""".format(**reproducibility_notes, **tf32_notes)
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kT , kH , kW)`
    bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple `(sT, sH, sW)`. Default: 1
    padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'},
      single number or a tuple `(padT, padH, padW)`. Default: 0
      ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
      the input so the output has the same shape as the input. However, this mode
      doesn't support any stride values other than 1.

      .. warning::
          For ``padding='same'``, if the ``weight`` is even-length and
          ``dilation`` is odd in any dimension, a full :func:`pad` operation
          may be needed internally, lowering performance.
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple `(dT, dH, dW)`. Default: 1
    groups: split input into groups, both :math:`\text{in\_channels}` and :math:`\text{out\_channels}`
      should be divisible by the number of groups. Default: 1

Examples::

    >>> filters = torch.randn(33, 16, 3, 3, 3)
    >>> inputs = torch.randn(20, 16, 50, 10, 20)
    >>> F.conv3d(inputs, filters)
""",
)  # noqa: E501
# Attach the user-facing docstring to the native ``torch.conv_transpose1d``
# builtin and expose the result under this module's ``conv_transpose1d`` name.
#
# Only the first raw string is passed through ``str.format`` (it carries the
# ``{tf32_note}`` / ``{cudnn_reproducibility_note}`` placeholders). The second
# one holds literal LaTeX braces (``\frac{...}{...}``) and is concatenated
# unformatted.
conv_transpose1d = _add_docstr(
    torch.conv_transpose1d,
    r"""
conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 1D transposed convolution operator over an input signal
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose1d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(**reproducibility_notes, **tf32_notes)
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None``
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sW,)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padW,)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple ``(out_padW,)``. Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple ``(dW,)``. Default: 1

Examples::

    >>> inputs = torch.randn(20, 16, 50)
    >>> weights = torch.randn(16, 33, 5)
    >>> F.conv_transpose1d(inputs, weights)
""",
)
# `_add_docstr` is handed the native ``torch.conv_transpose2d`` op together with
# its documentation text; whatever it returns is bound to the module-level name.
# Only the first half of the text goes through ``str.format`` (to splice in the
# shared TF32 / cuDNN-reproducibility notes); the ``Args`` half stays raw so its
# LaTeX braces are not interpreted as format fields.
conv_transpose2d = _add_docstr(
    torch.conv_transpose2d,
    r"""
conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 2D transposed convolution operator over an input image
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose2d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(**reproducibility_notes, **tf32_notes)
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kH , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sH, sW)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padH, padW)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple ``(out_padH, out_padW)``.
      Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple ``(dH, dW)``. Default: 1

Examples::

    >>> # With square kernels and equal stride
    >>> inputs = torch.randn(1, 4, 5, 5)
    >>> weights = torch.randn(4, 8, 3, 3)
    >>> F.conv_transpose2d(inputs, weights, padding=1)
""",
)  # noqa: E501
# `_add_docstr` is handed the native ``torch.conv_transpose3d`` op together with
# its documentation text; whatever it returns is bound to the module-level name.
# Only the first half of the text goes through ``str.format`` (to splice in the
# shared TF32 / cuDNN-reproducibility notes); the ``Args`` half stays raw so its
# LaTeX braces are not interpreted as format fields.
conv_transpose3d = _add_docstr(
    torch.conv_transpose3d,
    r"""
conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor

Applies a 3D transposed convolution operator over an input image
composed of several input planes, sometimes also called "deconvolution".

{tf32_note}

See :class:`~torch.nn.ConvTranspose3d` for details and output shape.

Note:
    {cudnn_reproducibility_note}
""".format(**reproducibility_notes, **tf32_notes)
    + r"""

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kT , kH , kW)`
    bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None
    stride: the stride of the convolving kernel. Can be a single number or a
      tuple ``(sT, sH, sW)``. Default: 1
    padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
      sides of each dimension in the input. Can be a single number or a tuple
      ``(padT, padH, padW)``. Default: 0
    output_padding: additional size added to one side of each dimension in the
      output shape. Can be a single number or a tuple
      ``(out_padT, out_padH, out_padW)``. Default: 0
    groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
      number of groups. Default: 1
    dilation: the spacing between kernel elements. Can be a single number or
      a tuple ``(dT, dH, dW)``. Default: 1

Examples::

    >>> inputs = torch.randn(20, 16, 50, 10, 20)
    >>> weights = torch.randn(16, 33, 3, 3, 3)
    >>> F.conv_transpose3d(inputs, weights)
""",
)  # noqa: E501
# `_add_docstr` is handed the native ``torch.conv_tbc`` op together with its
# documentation text; whatever it returns is bound to the module-level name.
# The text opens with a signature line, like the neighbouring entries, built
# from the parameters and default listed under ``Args``.
conv_tbc = _add_docstr(
    torch.conv_tbc,
    r"""
conv_tbc(input, weight, bias, pad=0) -> Tensor

Applies a 1-dimensional sequence convolution over an input sequence.
Input and output dimensions are (Time, Batch, Channels) - hence TBC.

Args:
    input: input tensor of shape :math:`(\text{sequence length} \times batch \times \text{in\_channels})`
    weight: filter of shape (:math:`\text{kernel width} \times \text{in\_channels} \times \text{out\_channels}`)
    bias: bias of shape (:math:`\text{out\_channels}`)
    pad: number of timesteps to pad. Default: 0
""",
)
- # Pooling
# `_add_docstr` is handed the native ``torch.avg_pool1d`` op together with its
# documentation text; whatever it returns is bound to the module-level name.
# NOTE(review): the ``padding`` entry below bounds the value using ``dilation``,
# which is not a parameter in this docstring's signature line — confirm the
# intended wording.
avg_pool1d = _add_docstr(
    torch.avg_pool1d,
    r"""
avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor

Applies a 1D average pooling over an input signal composed of several
input planes.

See :class:`~torch.nn.AvgPool1d` for details and output shape.

Args:
    input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
    kernel_size: the size of the window. Can be a single number or a
      tuple `(kW,)`
    stride: the stride of the window. Can be a single number or a tuple
      `(sW,)`. Default: :attr:`kernel_size`
    padding: implicit zero paddings on both sides of the input. Can be a single
      number or a tuple `(padW,)`. Should be at most half of effective kernel
      size, that is :math:`((kernelSize - 1) * dilation + 1) / 2`. Default: 0
    ceil_mode: when True, will use `ceil` instead of `floor` to compute the
      output shape. Default: ``False``
    count_include_pad: when True, will include the zero-padding in the
      averaging calculation. Default: ``True``

Examples::

    >>> # pool of square window of size=3, stride=2
    >>> input = torch.tensor([[[1, 2, 3, 4, 5, 6, 7]]], dtype=torch.float32)
    >>> F.avg_pool1d(input, kernel_size=3, stride=2)
    tensor([[[ 2., 4., 6.]]])
""",
)
# `_add_docstr` is handed the native ``torch._C._nn.avg_pool2d`` op together
# with its documentation text; whatever it returns is bound to the module-level
# name.
# NOTE(review): the ``padding`` entry below bounds the value using ``dilation``,
# which is not a parameter in this docstring's signature line — confirm the
# intended wording.
avg_pool2d = _add_docstr(
    torch._C._nn.avg_pool2d,
    r"""
avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor

Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size
:math:`sH \times sW` steps. The number of output features is equal to the number of
input planes.

See :class:`~torch.nn.AvgPool2d` for details and output shape.

Args:
    input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
    kernel_size: size of the pooling region. Can be a single number, a single-element tuple or a
      tuple `(kH, kW)`
    stride: stride of the pooling operation. Can be a single number, a single-element tuple or a
      tuple `(sH, sW)`. Default: :attr:`kernel_size`
    padding: implicit zero paddings on both sides of the input. Can be a
      single number, a single-element tuple or a tuple `(padH, padW)`.
      Should be at most half of effective kernel size, that
      is :math:`((kernelSize - 1) * dilation + 1) / 2`. Default: 0
    ceil_mode: when True, will use `ceil` instead of `floor` in the formula
      to compute the output shape. Default: ``False``
    count_include_pad: when True, will include the zero-padding in the
      averaging calculation. Default: ``True``
    divisor_override: if specified, it will be used as divisor, otherwise
      size of the pooling region will be used. Default: None
""",
)
# `_add_docstr` is handed the native ``torch._C._nn.avg_pool3d`` op together
# with its documentation text; whatever it returns is bound to the module-level
# name.
# NOTE(review): the ``padding`` entry below bounds the value using ``dilation``,
# which is not a parameter in this docstring's signature line — confirm the
# intended wording.
avg_pool3d = _add_docstr(
    torch._C._nn.avg_pool3d,
    r"""
avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor

Applies 3D average-pooling operation in :math:`kT \times kH \times kW` regions by step
size :math:`sT \times sH \times sW` steps. The number of output features is equal to the
number of input planes.

See :class:`~torch.nn.AvgPool3d` for details and output shape.

Args:
    input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
    kernel_size: size of the pooling region. Can be a single number or a
      tuple `(kT, kH, kW)`
    stride: stride of the pooling operation. Can be a single number or a
      tuple `(sT, sH, sW)`. Default: :attr:`kernel_size`
    padding: implicit zero paddings on both sides of the input. Can be a
      single number or a tuple `(padT, padH, padW)`. Should be at most half
      of effective kernel size, that is :math:`((kernelSize - 1) * dilation + 1) / 2`.
      Default: 0
    ceil_mode: when True, will use `ceil` instead of `floor` in the formula
      to compute the output shape. Default: ``False``
    count_include_pad: when True, will include the zero-padding in the
      averaging calculation. Default: ``True``
    divisor_override: if specified, it will be used as divisor, otherwise
      size of the pooling region will be used. Default: None
""",
)
def fractional_max_pool2d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    output_size: Optional[BroadcastingList2[int]] = None,
    output_ratio: Optional[BroadcastingList2[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)

    Applies 2D fractional max pooling over an input signal composed of several input planes.

    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham.

    The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic
    step size determined by the target output size.
    The number of output features is equal to the number of input planes.

    Args:
        input: the input tensor. A 3-D input is treated as a single sample when
                     the random samples are drawn; otherwise the leading dimension
                     is used as the batch size.
        kernel_size: the size of the window to take a max over.
                     Can be a single number :math:`k` (for a square kernel of :math:`k \times k`)
                     or a tuple `(kH, kW)`
        output_size: the target output size of the image of the form :math:`oH \times oW`.
                     Can be a tuple `(oH, oW)` or a single number :math:`oH` for a square image :math:`oH \times oH`
        output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
                      This has to be a number or tuple in the range (0, 1).
                      Only consulted when :attr:`output_size` is ``None``.
        return_indices: if ``True``, will return the indices along with the outputs.
                        Useful to pass to :func:`~torch.nn.functional.max_unpool2d`.

    Raises:
        ValueError: if neither :attr:`output_size` nor :attr:`output_ratio` is given,
            or if :attr:`output_ratio` has more than two elements.

    Examples::
        >>> input = torch.randn(20, 16, 50, 32)
        >>> # pool of square window of size=3, and target output size 13x12
        >>> F.fractional_max_pool2d(input, 3, output_size=(13, 12))
        >>> # pool of square window and target output size being half of input image size
        >>> F.fractional_max_pool2d(input, 3, output_ratio=(0.5, 0.5))

    .. _Fractional MaxPooling:
        http://arxiv.org/abs/1412.6071
    """
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool2d_with_indices,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    if output_size is None and output_ratio is None:
        raise ValueError(
            "fractional_max_pool2d requires specifying either an output_size or an output_ratio"
        )
    if output_size is None:
        assert output_ratio is not None
        # Broadcast first, validate second: the docstring allows a bare number,
        # and calling len() on a scalar raises TypeError. `_pair` leaves
        # sequences as they are, so over-long ratios are still rejected below.
        _output_ratio = _pair(output_ratio)
        if len(_output_ratio) > 2:
            # Message text is kept unchanged for callers that match on it.
            raise ValueError(
                "fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints."
            )
        # Target size = input height/width scaled by the ratio, truncated to int.
        output_size = [
            int(input.size(-2) * _output_ratio[0]),
            int(input.size(-1) * _output_ratio[1]),
        ]

    if _random_samples is None:
        # One pair of uniform samples per (sample, channel); a 3-D input has no
        # batch dimension, so a batch of one is used.
        n_batch = 1 if input.dim() == 3 else input.size(0)
        _random_samples = torch.rand(
            n_batch, input.size(-3), 2, dtype=input.dtype, device=input.device
        )
    return torch._C._nn.fractional_max_pool2d(
        input, kernel_size, output_size, _random_samples
    )
- def _fractional_max_pool2d(
- input: Tensor,
- kernel_size: BroadcastingList2[int],
- output_size: Optional[BroadcastingList2[int]] = None,
- output_ratio: Optional[BroadcastingList2[float]] = None,
- return_indices: bool = False,
- _random_samples: Optional[Tensor] = None,
- ) -> Tensor:
- if has_torch_function_variadic(input, _random_samples):
- return handle_torch_function(
- fractional_max_pool2d,
- (input, _random_samples),
- input,
- kernel_size,
- output_size=output_size,
- output_ratio=output_ratio,
- return_indices=return_indices,
- _random_samples=_random_samples,
- )
- return fractional_max_pool2d_with_indices(
- input, kernel_size, output_size, output_ratio, return_indices, _random_samples
- )[0]
- fractional_max_pool2d = boolean_dispatch(
- arg_name="return_indices",
- arg_index=4,
- default=False,
- if_true=fractional_max_pool2d_with_indices,
- if_false=_fractional_max_pool2d,
- module_name=__name__,
- func_name="fractional_max_pool2d",
- )
def fractional_max_pool3d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    output_size: Optional[BroadcastingList3[int]] = None,
    output_ratio: Optional[BroadcastingList3[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None,
) -> tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)

    Applies 3D fractional max pooling over an input signal composed of several input planes.

    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham.

    The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic
    step size determined by the target output size.
    The number of output features is equal to the number of input planes.

    Args:
        kernel_size: the size of the window to take a max over.
                     Can be a single number :math:`k` (for a square kernel of :math:`k \times k \times k`)
                     or a tuple `(kT, kH, kW)`
        output_size: the target output size of the form :math:`oT \times oH \times oW`.
                     Can be a tuple `(oT, oH, oW)` or a single number :math:`oH` for a cubic output
                     :math:`oH \times oH \times oH`
        output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
                      This has to be a number or tuple in the range (0, 1).
                      Only consulted when :attr:`output_size` is ``None``.
        return_indices: if ``True``, will return the indices along with the outputs.
                        Useful to pass to :func:`~torch.nn.functional.max_unpool3d`.

    Shape:
        - Input: :math:`(N, C, T_{in}, H_{in}, W_{in})` or :math:`(C, T_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, T_{out}, H_{out}, W_{out})` or :math:`(C, T_{out}, H_{out}, W_{out})`, where
          :math:`(T_{out}, H_{out}, W_{out})=\text{output\_size}` or
          :math:`(T_{out}, H_{out}, W_{out})=\text{output\_ratio} \times (T_{in}, H_{in}, W_{in})`

    Examples::
        >>> input = torch.randn(20, 16, 50, 32, 16)
        >>> # pool of cubic window of size=3, and target output size 13x12x11
        >>> F.fractional_max_pool3d(input, 3, output_size=(13, 12, 11))
        >>> # pool of cubic window and target output size being half of input size
        >>> F.fractional_max_pool3d(input, 3, output_ratio=(0.5, 0.5, 0.5))

    .. _Fractional MaxPooling:
        http://arxiv.org/abs/1412.6071
    """
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool3d_with_indices,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    if output_size is None and output_ratio is None:
        raise ValueError(
            "fractional_max_pool3d requires specifying either an output_size or an output_ratio"
        )
    if output_size is None:
        assert output_ratio is not None
        ratios = _triple(output_ratio)
        # Each target extent is the matching input extent scaled by its ratio
        # and truncated to an int.
        out_t = int(input.size(-3) * ratios[0])
        out_h = int(input.size(-2) * ratios[1])
        out_w = int(input.size(-1) * ratios[2])
        output_size = [out_t, out_h, out_w]

    if _random_samples is None:
        # A 4-D input carries no batch dimension, so a batch of one is used;
        # three uniform samples are drawn per (sample, channel).
        if input.dim() == 4:
            batch = 1
        else:
            batch = input.size(0)
        _random_samples = torch.rand(
            batch, input.size(-4), 3, dtype=input.dtype, device=input.device
        )
    pooled_pair = torch._C._nn.fractional_max_pool3d(
        input, kernel_size, output_size, _random_samples
    )
    return pooled_pair
- def _fractional_max_pool3d(
- input: Tensor,
- kernel_size: BroadcastingList3[int],
- output_size: Optional[BroadcastingList3[int]] = None,
- output_ratio: Optional[BroadcastingList3[float]] = None,
- return_indices: bool = False,
- _random_samples: Optional[Tensor] = None,
- ) -> Tensor:
- if has_torch_function_variadic(input, _random_samples):
- return handle_torch_function(
- fractional_max_pool3d,
- (input, _random_samples),
- input,
- kernel_size,
- output_size=output_size,
- output_ratio=output_ratio,
- return_indices=return_indices,
- _random_samples=_random_samples,
- )
- return fractional_max_pool3d_with_indices(
- input, kernel_size, output_size, output_ratio, return_indices, _random_samples
- )[0]
- fractional_max_pool3d = boolean_dispatch(
- arg_name="return_indices",
- arg_index=4,
- default=False,
- if_true=fractional_max_pool3d_with_indices,
- if_false=_fractional_max_pool3d,
- module_name=__name__,
- func_name="fractional_max_pool3d",
- )
def max_pool1d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList1[int],
    stride: Optional[BroadcastingList1[int]] = None,
    padding: BroadcastingList1[int] = 0,
    dilation: BroadcastingList1[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

    Applies a 1D max pooling over an input signal composed of several input
    planes.

    .. note::
        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
        what is seen in :class:`~torch.nn.MaxPool1d`, and will change in a future release.

    See :class:`~torch.nn.MaxPool1d` for details.

    Args:
        input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`, minibatch dim optional.
        kernel_size: the size of the window. Can be a single number or a
            tuple `(kW,)`
        stride: the stride of the window. Can be a single number or a tuple
            `(sW,)`. Default: :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
                   ensures that every element in the input tensor is covered by a sliding window.
        return_indices: If ``True``, will return the argmax along with the max values.
                        Useful for :func:`~torch.nn.functional.max_unpool1d` later
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool1d_with_indices,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    # A missing stride is forwarded as an empty int list; the documented
    # "defaults to kernel_size" is presumably resolved by the native op.
    if stride is None:
        stride = torch.jit.annotate(list[int], [])
    values_and_indices = torch.max_pool1d_with_indices(
        input, kernel_size, stride, padding, dilation, ceil_mode
    )
    return values_and_indices
- def _max_pool1d(
- input: Tensor,
- kernel_size: BroadcastingList1[int],
- stride: Optional[BroadcastingList1[int]] = None,
- padding: BroadcastingList1[int] = 0,
- dilation: BroadcastingList1[int] = 1,
- ceil_mode: bool = False,
- return_indices: bool = False,
- ) -> Tensor:
- if has_torch_function_unary(input):
- return handle_torch_function(
- max_pool1d,
- (input,),
- input,
- kernel_size,
- stride=stride,
- padding=padding,
- dilation=dilation,
- ceil_mode=ceil_mode,
- return_indices=return_indices,
- )
- if stride is None:
- stride = torch.jit.annotate(list[int], [])
- return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)
- max_pool1d = boolean_dispatch(
- arg_name="return_indices",
- arg_index=6,
- default=False,
- if_true=max_pool1d_with_indices,
- if_false=_max_pool1d,
- module_name=__name__,
- func_name="max_pool1d",
- )
def max_pool2d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    stride: Optional[BroadcastingList2[int]] = None,
    padding: BroadcastingList2[int] = 0,
    dilation: BroadcastingList2[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

    Applies a 2D max pooling over an input signal composed of several input
    planes.

    .. note::
        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
        what is seen in :class:`~torch.nn.MaxPool2d`, and will change in a future release.

    See :class:`~torch.nn.MaxPool2d` for details.

    Args:
        input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`, minibatch dim optional.
        kernel_size: size of the pooling region. Can be a single number or a
            tuple `(kH, kW)`
        stride: stride of the pooling operation. Can be a single number or a
            tuple `(sH, sW)`. Default: :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
                   ensures that every element in the input tensor is covered by a sliding window.
        return_indices: If ``True``, will return the argmax along with the max values.
                        Useful for :func:`~torch.nn.functional.max_unpool2d` later
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool2d_with_indices,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    # A missing stride is forwarded as an empty int list; the documented
    # "defaults to kernel_size" is presumably resolved by the native op.
    if stride is None:
        stride = torch.jit.annotate(list[int], [])
    values_and_indices = torch._C._nn.max_pool2d_with_indices(
        input, kernel_size, stride, padding, dilation, ceil_mode
    )
    return values_and_indices
- def _max_pool2d(
- input: Tensor,
- kernel_size: BroadcastingList2[int],
- stride: Optional[BroadcastingList2[int]] = None,
- padding: BroadcastingList2[int] = 0,
- dilation: BroadcastingList2[int] = 1,
- ceil_mode: bool = False,
- return_indices: bool = False,
- ) -> Tensor:
- if has_torch_function_unary(input):
- return handle_torch_function(
- max_pool2d,
- (input,),
- input,
- kernel_size,
- stride=stride,
- padding=padding,
- dilation=dilation,
- ceil_mode=ceil_mode,
- return_indices=return_indices,
- )
- if stride is None:
- stride = torch.jit.annotate(list[int], [])
- return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
- max_pool2d = boolean_dispatch(
- arg_name="return_indices",
- arg_index=6,
- default=False,
- if_true=max_pool2d_with_indices,
- if_false=_max_pool2d,
- module_name=__name__,
- func_name="max_pool2d",
- )
def max_pool3d_with_indices(
    input: Tensor,
    kernel_size: BroadcastingList3[int],
    stride: Optional[BroadcastingList3[int]] = None,
    padding: BroadcastingList3[int] = 0,
    dilation: BroadcastingList3[int] = 1,
    ceil_mode: bool = False,
    return_indices: bool = False,
) -> tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

    Applies a 3D max pooling over an input signal composed of several input
    planes.

    .. note::
        The order of :attr:`ceil_mode` and :attr:`return_indices` is different from
        what is seen in :class:`~torch.nn.MaxPool3d`, and will change in a future release.

    See :class:`~torch.nn.MaxPool3d` for details.

    Args:
        input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`, minibatch dim optional.
        kernel_size: size of the pooling region. Can be a single number or a
                     tuple `(kT, kH, kW)`
        stride: stride of the pooling operation. Can be a single number or a
                tuple `(sT, sH, sW)`. Default: :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
                   ensures that every element in the input tensor is covered by a sliding window.
        return_indices: If ``True``, will return the argmax along with the max values.
                        Useful for :func:`~torch.nn.functional.max_unpool3d` later
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_pool3d_with_indices,
            (input,),
            input,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
            return_indices=return_indices,
        )
    # A missing stride is forwarded as an empty int list; the documented
    # "defaults to kernel_size" is presumably resolved by the native op.
    if stride is None:
        stride = torch.jit.annotate(list[int], [])
    values_and_indices = torch._C._nn.max_pool3d_with_indices(
        input, kernel_size, stride, padding, dilation, ceil_mode
    )
    return values_and_indices
- def _max_pool3d(
- input: Tensor,
- kernel_size: BroadcastingList3[int],
- stride: Optional[BroadcastingList3[int]] = None,
- padding: BroadcastingList3[int] = 0,
- dilation: BroadcastingList3[int] = 1,
- ceil_mode: bool = False,
- return_indices: bool = False,
- ) -> Tensor:
- if has_torch_function_unary(input):
- return handle_torch_function(
- max_pool3d,
- (input,),
- input,
- kernel_size,
- stride=stride,
- padding=padding,
- dilation=dilation,
- ceil_mode=ceil_mode,
- return_indices=return_indices,
- )
- if stride is None:
- stride = torch.jit.annotate(list[int], [])
- return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)
- max_pool3d = boolean_dispatch(
- arg_name="return_indices",
- arg_index=6,
- default=False,
- if_true=max_pool3d_with_indices,
- if_false=_max_pool3d,
- module_name=__name__,
- func_name="max_pool3d",
- )
- def _unpool_output_size(
- input: Tensor,
- kernel_size: list[int],
- stride: list[int],
- padding: list[int],
- output_size: Optional[list[int]],
- ) -> list[int]:
- input_size = input.size()
- default_size = torch.jit.annotate(list[int], [])
- for d in range(len(kernel_size)):
- default_size.append(
- (input_size[-len(kernel_size) + d] - 1) * stride[d]
- + kernel_size[d]
- - 2 * padding[d]
- )
- if output_size is None:
- ret = default_size
- else:
- if len(output_size) == len(kernel_size) + 2:
- output_size = output_size[2:]
- if len(output_size) != len(kernel_size):
- raise ValueError(
- "output_size should be a sequence containing "
- f"{len(kernel_size)} or {len(kernel_size) + 2} elements, but it has a length of '{len(output_size)}'"
- )
- for d in range(len(kernel_size)):
- min_size = default_size[d] - stride[d]
- max_size = default_size[d] + stride[d]
- if not (min_size < output_size[d] < max_size):
- raise ValueError(
- f'invalid output_size "{output_size}" (dim {d} must be between {min_size} and {max_size})'
- )
- ret = output_size
- return ret
def max_unpool1d(
    input: Tensor,
    indices: Tensor,
    kernel_size: BroadcastingList1[int],
    stride: Optional[BroadcastingList1[int]] = None,
    padding: BroadcastingList1[int] = 0,
    output_size: Optional[BroadcastingList1[int]] = None,
) -> Tensor:
    r"""Compute a partial inverse of :class:`MaxPool1d`.

    See :class:`~torch.nn.MaxUnpool1d` for details.

    ``kernel_size``, ``stride`` (falling back to ``kernel_size`` when ``None``)
    and ``padding`` are used here only to derive or validate the output size.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_unpool1d,
            (input,),
            input,
            indices,
            kernel_size,
            stride=stride,
            padding=padding,
            output_size=output_size,
        )
    kernel_size = _single(kernel_size)
    if stride is None:
        _stride = kernel_size
    else:
        _stride = _single(stride)
    padding = _single(padding)
    output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
    # The 2D unpooling op does the work: give the target size, the input and
    # the indices a trailing dimension of one, then drop it from the result.
    if isinstance(output_size, list):
        output_size = output_size + [1]
    else:
        output_size = output_size + (1,)
    input_2d = input.unsqueeze(-1)
    indices_2d = indices.unsqueeze(-1)
    unpooled_2d = torch._C._nn.max_unpool2d(input_2d, indices_2d, output_size)
    return unpooled_2d.squeeze(-1)
def max_unpool2d(
    input: Tensor,
    indices: Tensor,
    kernel_size: BroadcastingList2[int],
    stride: Optional[BroadcastingList2[int]] = None,
    padding: BroadcastingList2[int] = 0,
    output_size: Optional[BroadcastingList2[int]] = None,
) -> Tensor:
    r"""Compute a partial inverse of :class:`MaxPool2d`.

    See :class:`~torch.nn.MaxUnpool2d` for details.

    ``kernel_size``, ``stride`` (falling back to ``kernel_size`` when ``None``)
    and ``padding`` are used here only to derive or validate the output size.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_unpool2d,
            (input,),
            input,
            indices,
            kernel_size,
            stride=stride,
            padding=padding,
            output_size=output_size,
        )
    kernel_size = _pair(kernel_size)
    if stride is None:
        _stride = kernel_size
    else:
        _stride = _pair(stride)
    padding = _pair(padding)
    output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
    unpooled = torch._C._nn.max_unpool2d(input, indices, output_size)
    return unpooled
def max_unpool3d(
    input: Tensor,
    indices: Tensor,
    kernel_size: BroadcastingList3[int],
    stride: Optional[BroadcastingList3[int]] = None,
    padding: BroadcastingList3[int] = 0,
    output_size: Optional[BroadcastingList3[int]] = None,
) -> Tensor:
    r"""Compute a partial inverse of :class:`MaxPool3d`.

    See :class:`~torch.nn.MaxUnpool3d` for details.

    ``stride`` falls back to ``kernel_size`` when ``None``. Unlike the 1D/2D
    variants, the resolved stride and padding are also handed to the native op.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            max_unpool3d,
            (input,),
            input,
            indices,
            kernel_size,
            stride=stride,
            padding=padding,
            output_size=output_size,
        )
    kernel_size = _triple(kernel_size)
    if stride is None:
        _stride = kernel_size
    else:
        _stride = _triple(stride)
    padding = _triple(padding)
    output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)
    unpooled = torch._C._nn.max_unpool3d(input, indices, output_size, _stride, padding)
    return unpooled
def lp_pool3d(
    input: Tensor,
    norm_type: Union[int, float],
    kernel_size: BroadcastingList3[int],
    stride: Optional[BroadcastingList3[int]] = None,
    ceil_mode: bool = False,
) -> Tensor:
    r"""
    Apply a 3D power-average pooling over an input signal composed of several input planes.

    If the sum of all inputs to the power of `p` is
    zero, the gradient is set to zero as well.

    The result is computed as the average pool of ``input ** norm_type``,
    multiplied by the number of elements in the kernel, raised to
    ``1 / norm_type``.

    See :class:`~torch.nn.LPPool3d` for details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            lp_pool3d,
            (input,),
            input,
            norm_type,
            kernel_size,
            stride=stride,
            ceil_mode=ceil_mode,
        )
    k0, k1, k2 = _triple(kernel_size)
    powered = input.pow(norm_type)
    if stride is None:
        # No stride given: leave it to avg_pool3d's own default.
        averaged = avg_pool3d(powered, kernel_size, padding=0, ceil_mode=ceil_mode)
    else:
        averaged = avg_pool3d(powered, kernel_size, stride, 0, ceil_mode)

    # sign(x) * relu(|x|) reproduces x; presumably written this way for the
    # zero-gradient behaviour mentioned in the docstring.
    rectified = torch.sign(averaged) * relu(torch.abs(averaged))
    # Average times kernel element count gives back the windowed sum.
    return rectified.mul(k0 * k1 * k2).pow(1.0 / norm_type)
def lp_pool2d(
    input: Tensor,
    norm_type: Union[int, float],
    kernel_size: BroadcastingList2[int],
    stride: Optional[BroadcastingList2[int]] = None,
    ceil_mode: bool = False,
) -> Tensor:
    r"""
    Apply a 2D power-average pooling over an input signal composed of several input planes.

    If the sum of all inputs to the power of `p` is
    zero, the gradient is set to zero as well.

    The result is computed as the average pool of ``input ** norm_type``,
    multiplied by the number of elements in the kernel, raised to
    ``1 / norm_type``.

    See :class:`~torch.nn.LPPool2d` for details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            lp_pool2d,
            (input,),
            input,
            norm_type,
            kernel_size,
            stride=stride,
            ceil_mode=ceil_mode,
        )
    k0, k1 = _pair(kernel_size)
    powered = input.pow(norm_type)
    if stride is None:
        # No stride given: leave it to avg_pool2d's own default.
        averaged = avg_pool2d(powered, kernel_size, padding=0, ceil_mode=ceil_mode)
    else:
        averaged = avg_pool2d(powered, kernel_size, stride, 0, ceil_mode)

    # sign(x) * relu(|x|) reproduces x; presumably written this way for the
    # zero-gradient behaviour mentioned in the docstring.
    rectified = torch.sign(averaged) * relu(torch.abs(averaged))
    # Average times kernel element count gives back the windowed sum.
    return rectified.mul(k0 * k1).pow(1.0 / norm_type)
def lp_pool1d(
    input: Tensor,
    norm_type: Union[int, float],
    kernel_size: int,
    stride: Optional[BroadcastingList1[int]] = None,
    ceil_mode: bool = False,
) -> Tensor:
    r"""Apply a 1D power-average pooling over an input signal composed of several input planes.

    If the sum of all inputs to the power of `p` is
    zero, the gradient is set to zero as well.

    The result is computed as the average pool of ``input ** norm_type``,
    multiplied by ``kernel_size``, raised to ``1 / norm_type``.

    See :class:`~torch.nn.LPPool1d` for details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            lp_pool1d,
            (input,),
            input,
            norm_type,
            kernel_size,
            stride=stride,
            ceil_mode=ceil_mode,
        )
    powered = input.pow(norm_type)
    if stride is None:
        # No stride given: leave it to avg_pool1d's own default.
        averaged = avg_pool1d(powered, kernel_size, padding=0, ceil_mode=ceil_mode)
    else:
        averaged = avg_pool1d(powered, kernel_size, stride, 0, ceil_mode)

    # sign(x) * relu(|x|) reproduces x; presumably written this way for the
    # zero-gradient behaviour mentioned in the docstring.
    rectified = torch.sign(averaged) * relu(torch.abs(averaged))
    # Average times window length gives back the windowed sum.
    return rectified.mul(kernel_size).pow(1.0 / norm_type)
def adaptive_max_pool1d_with_indices(
    input: Tensor,
    output_size: BroadcastingList1[int],
    return_indices: bool = False,
) -> tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    adaptive_max_pool1d(input, output_size, return_indices=False)

    Applies a 1D adaptive max pooling over an input signal composed of
    several input planes.

    See :class:`~torch.nn.AdaptiveMaxPool1d` for details and output shape.

    Args:
        output_size: the target output size (single integer)
        return_indices: whether to return pooling indices. Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool1d_with_indices,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    # ``return_indices`` is not consulted here: the native op's full result
    # is handed back unchanged.
    values_and_indices = torch.adaptive_max_pool1d(input, output_size)
    return values_and_indices
def _adaptive_max_pool1d(
    input: Tensor,
    output_size: BroadcastingList1[int],
    return_indices: bool = False,
) -> Tensor:
    # Values-only flavour: runs the with-indices implementation and keeps the
    # first element of its result. Deliberately has no docstring of its own.
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool1d,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    pooled, _indices = adaptive_max_pool1d_with_indices(input, output_size)
    return pooled
# Public entry point: ``boolean_dispatch`` is configured to choose between the
# two implementations based on ``return_indices`` (positional slot 2,
# ``False`` when omitted).
adaptive_max_pool1d = boolean_dispatch(
    func_name="adaptive_max_pool1d",
    module_name=__name__,
    arg_name="return_indices",
    arg_index=2,
    default=False,
    if_true=adaptive_max_pool1d_with_indices,
    if_false=_adaptive_max_pool1d,
)
def adaptive_max_pool2d_with_indices(
    input: Tensor,
    output_size: BroadcastingList2[int],
    return_indices: bool = False,
) -> tuple[Tensor, Tensor]:  # noqa: D400
    r"""adaptive_max_pool2d(input, output_size, return_indices=False)

    Applies a 2D adaptive max pooling over an input signal composed of
    several input planes.

    See :class:`~torch.nn.AdaptiveMaxPool2d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
            double-integer tuple)
        return_indices: whether to return pooling indices. Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool2d_with_indices,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    # Resolve the requested size against the input's own size before calling
    # the native op; ``return_indices`` is not consulted here.
    target_size = _list_with_default(output_size, input.size())
    return torch._C._nn.adaptive_max_pool2d(input, target_size)
def _adaptive_max_pool2d(
    input: Tensor,
    output_size: BroadcastingList2[int],
    return_indices: bool = False,
) -> Tensor:
    # Values-only flavour: runs the with-indices implementation and keeps the
    # first element of its result. Deliberately has no docstring of its own.
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool2d,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    pooled, _indices = adaptive_max_pool2d_with_indices(input, output_size)
    return pooled
# Public entry point: ``boolean_dispatch`` is configured to choose between the
# two implementations based on ``return_indices`` (positional slot 2,
# ``False`` when omitted).
adaptive_max_pool2d = boolean_dispatch(
    func_name="adaptive_max_pool2d",
    module_name=__name__,
    arg_name="return_indices",
    arg_index=2,
    default=False,
    if_true=adaptive_max_pool2d_with_indices,
    if_false=_adaptive_max_pool2d,
)
def adaptive_max_pool3d_with_indices(
    input: Tensor,
    output_size: BroadcastingList3[int],
    return_indices: bool = False,
) -> tuple[Tensor, Tensor]:  # noqa: D400
    r"""
    adaptive_max_pool3d(input, output_size, return_indices=False)

    Applies a 3D adaptive max pooling over an input signal composed of
    several input planes.

    See :class:`~torch.nn.AdaptiveMaxPool3d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
            triple-integer tuple)
        return_indices: whether to return pooling indices. Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool3d_with_indices,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    # Resolve the requested size against the input's own size before calling
    # the native op; ``return_indices`` is not consulted here.
    target_size = _list_with_default(output_size, input.size())
    return torch._C._nn.adaptive_max_pool3d(input, target_size)
def _adaptive_max_pool3d(
    input: Tensor,
    output_size: BroadcastingList3[int],
    return_indices: bool = False,
) -> Tensor:
    # Values-only flavour: runs the with-indices implementation and keeps the
    # first element of its result. Deliberately has no docstring of its own.
    if has_torch_function_unary(input):
        return handle_torch_function(
            adaptive_max_pool3d,
            (input,),
            input,
            output_size,
            return_indices=return_indices,
        )
    pooled, _indices = adaptive_max_pool3d_with_indices(input, output_size)
    return pooled
# Public entry point: ``boolean_dispatch`` is configured to choose between the
# two implementations based on ``return_indices`` (positional slot 2,
# ``False`` when omitted).
adaptive_max_pool3d = boolean_dispatch(
    func_name="adaptive_max_pool3d",
    module_name=__name__,
    arg_name="return_indices",
    arg_index=2,
    default=False,
    if_true=adaptive_max_pool3d_with_indices,
    if_false=_adaptive_max_pool3d,
)
# Bind the native ``torch.adaptive_avg_pool1d`` under this module's namespace;
# ``_add_docstr`` is handed the documentation text below.
adaptive_avg_pool1d = _add_docstr(
    torch.adaptive_avg_pool1d,
    r"""
adaptive_avg_pool1d(input, output_size) -> Tensor
Applies a 1D adaptive average pooling over an input signal composed of
several input planes.
See :class:`~torch.nn.AdaptiveAvgPool1d` for details and output shape.
Args:
output_size: the target output size (single integer)
""",
)
def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
    r"""Apply 2D adaptive average pooling to an input made of several input planes.

    See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape.

    Args:
        input: tensor to pool.
        output_size: the target output size (single integer or
            double-integer tuple). It is resolved against ``input.size()``
            with ``_list_with_default`` before the native op is called.

    Returns:
        The pooled tensor.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size)
    target_size = _list_with_default(output_size, input.size())
    pooled = torch._C._nn.adaptive_avg_pool2d(input, target_size)
    return pooled
def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> Tensor:
    r"""Apply 3D adaptive average pooling to an input made of several input planes.

    See :class:`~torch.nn.AdaptiveAvgPool3d` for details and output shape.

    Args:
        input: tensor to pool.
        output_size: the target output size (single integer or
            triple-integer tuple). It is resolved against ``input.size()``
            with ``_list_with_default`` before the native op is called.

    Returns:
        The pooled tensor.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
    target_size = _list_with_default(output_size, input.size())
    pooled = torch._C._nn.adaptive_avg_pool3d(input, target_size)
    return pooled
- # Activation functions
def dropout(
    input: Tensor,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
) -> Tensor:
    r"""During training, randomly zero elements of the input with probability :attr:`p`.

    Uses samples from a Bernoulli distribution.

    See :class:`~torch.nn.Dropout` for details.

    Args:
        input: tensor to apply dropout to.
        p: probability of an element to be zeroed. Default: 0.5
        training: apply dropout if is ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``

    Raises:
        ValueError: if ``p`` is below 0 or above 1.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            dropout, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    if inplace:
        return _VF.dropout_(input, p, training)
    return _VF.dropout(input, p, training)
def alpha_dropout(
    input: Tensor,
    p: float = 0.5,
    training: bool = False,
    inplace: bool = False,
) -> Tensor:
    r"""Apply alpha dropout to the input.

    See :class:`~torch.nn.AlphaDropout` for details.

    Args:
        input: tensor to apply alpha dropout to.
        p: dropout probability. Default: 0.5
        training: apply dropout if is ``True``. Default: ``False``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``

    Raises:
        ValueError: if ``p`` is below 0 or above 1.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            alpha_dropout, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    if inplace:
        return _VF.alpha_dropout_(input, p, training)
    return _VF.alpha_dropout(input, p, training)
def dropout1d(
    input: Tensor,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
) -> Tensor:
    r"""Randomly zero out entire channels (a channel is a 1D feature map).

    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 1D tensor :math:`\text{input}[i, j]` of the input tensor.
    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    See :class:`~torch.nn.Dropout1d` for details.

    Args:
        input: 2D or 3D tensor; a 2D input is treated as having no batch
            dimension.
        p: probability of a channel to be zeroed. Default: 0.5
        training: apply dropout if is ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``

    Raises:
        ValueError: if ``p`` is below 0 or above 1.
        RuntimeError: if the input is neither 2D nor 3D.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            dropout1d, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    ndim = input.dim()
    if ndim not in (2, 3):
        raise RuntimeError(
            f"dropout1d: Expected 2D or 3D input, but received a {ndim}D input. "
            "Note that dropout1d exists to provide channel-wise dropout on inputs with 1 "
            "spatial dimension, a channel dimension, and an optional batch dimension "
            "(i.e. 2D or 3D inputs)."
        )

    # A 2D input gets a temporary leading dimension for the feature-dropout
    # call, which is removed again afterwards.
    add_batch_dim = ndim != 3
    if add_batch_dim:
        if inplace:
            input = input.unsqueeze_(0)
        else:
            input = input.unsqueeze(0)

    if inplace:
        result = _VF.feature_dropout_(input, p, training)
    else:
        result = _VF.feature_dropout(input, p, training)

    if add_batch_dim:
        if inplace:
            result = result.squeeze_(0)
        else:
            result = result.squeeze(0)
    return result
def dropout2d(
    input: Tensor,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
) -> Tensor:
    r"""Randomly zero out entire channels (a channel is a 2D feature map).

    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 2D tensor :math:`\text{input}[i, j]` of the input tensor.
    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    See :class:`~torch.nn.Dropout2d` for details.

    Args:
        input: tensor to apply dropout to. Inputs that are not 3D or 4D, and
            3D inputs, each trigger a warning (see the messages below).
        p: probability of a channel to be zeroed. Default: 0.5
        training: apply dropout if is ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``

    Raises:
        ValueError: if ``p`` is below 0 or above 1.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            dropout2d, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    ndim = input.dim()
    if ndim not in (3, 4):
        warnings.warn(
            f"dropout2d: Received a {ndim}-D input to dropout2d, which is deprecated "
            "and will result in an error in a future release. To retain the behavior "
            "and silence this warning, please use dropout instead. Note that dropout2d "
            "exists to provide channel-wise dropout on inputs with 2 spatial dimensions, "
            "a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs)."
        )

    # TODO: Properly support no-batch-dim inputs. For now, these are NOT supported; passing
    # a 3D input will perform dropout1d behavior instead. This was done historically and the
    # behavior is maintained here for now.
    # See https://github.com/pytorch/pytorch/issues/77081
    if ndim == 3:
        warnings.warn(
            "dropout2d: Received a 3D input to dropout2d and assuming that channel-wise "
            "1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C "
            "is the channel dim. This behavior will change in a future release to interpret the "
            "input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D "
            "channel-wise dropout behavior, please switch to using dropout1d instead."
        )

    if inplace:
        return _VF.feature_dropout_(input, p, training)
    return _VF.feature_dropout(input, p, training)
def dropout3d(
    input: Tensor,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
) -> Tensor:
    r"""Randomly zero out entire channels (a channel is a 3D feature map).

    For example, the :math:`j`-th channel of the :math:`i`-th sample in the
    batched input is a 3D tensor :math:`\text{input}[i, j]` of the input tensor.
    Each channel will be zeroed out independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.

    See :class:`~torch.nn.Dropout3d` for details.

    Args:
        input: tensor to apply dropout to. Inputs that are not 4D or 5D
            trigger a deprecation warning; any non-5D input is processed with
            a temporary leading dimension.
        p: probability of a channel to be zeroed. Default: 0.5
        training: apply dropout if is ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``

    Raises:
        ValueError: if ``p`` is below 0 or above 1.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            dropout3d, (input,), input, p=p, training=training, inplace=inplace
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    ndim = input.dim()
    if ndim not in (4, 5):
        warnings.warn(
            f"dropout3d: Received a {ndim}-D input to dropout3d, which is deprecated "
            "and will result in an error in a future release. To retain the behavior "
            "and silence this warning, please use dropout instead. Note that dropout3d "
            "exists to provide channel-wise dropout on inputs with 3 spatial dimensions, "
            "a channel dimension, and an optional batch dimension (i.e. 4D or 5D inputs)."
        )

    # Anything that is not 5D gets a temporary leading dimension for the
    # feature-dropout call, which is removed again afterwards.
    add_batch_dim = ndim != 5
    if add_batch_dim:
        if inplace:
            input = input.unsqueeze_(0)
        else:
            input = input.unsqueeze(0)

    if inplace:
        result = _VF.feature_dropout_(input, p, training)
    else:
        result = _VF.feature_dropout(input, p, training)

    if add_batch_dim:
        if inplace:
            result = result.squeeze_(0)
        else:
            result = result.squeeze(0)
    return result
def feature_alpha_dropout(
    input: Tensor,
    p: float = 0.5,
    training: bool = False,
    inplace: bool = False,
) -> Tensor:
    r"""Randomly masks out entire channels (a channel is a feature map).

    For example, the :math:`j`-th channel of the :math:`i`-th sample in the batch input
    is a tensor :math:`\text{input}[i, j]` of the input tensor. Instead of
    setting activations to zero, as in regular Dropout, the activations are set
    to the negative saturation value of the SELU activation function.

    Each element will be masked independently on every forward call with
    probability :attr:`p` using samples from a Bernoulli distribution.
    The elements to be masked are randomized on every forward call, and scaled
    and shifted to maintain zero mean and unit variance.

    See :class:`~torch.nn.FeatureAlphaDropout` for details.

    Args:
        input: tensor to apply dropout to.
        p: dropout probability of a channel to be masked. Default: 0.5
        training: apply dropout if is ``True``. Default: ``False``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``

    Raises:
        ValueError: if ``p`` is below 0 or above 1.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            feature_alpha_dropout,
            (input,),
            input,
            p=p,
            training=training,
            inplace=inplace,
        )
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    if inplace:
        return _VF.feature_alpha_dropout_(input, p, training)
    return _VF.feature_alpha_dropout(input, p, training)
def _threshold(
    input: Tensor,
    threshold: float,
    value: float,
    inplace: bool = False,
) -> Tensor:
    r"""Apply a threshold to each element of the input Tensor.

    See :class:`~torch.nn.Threshold` for more details.

    Args:
        input: tensor to threshold.
        threshold: forwarded to the native threshold op.
        value: forwarded to the native threshold op.
        inplace: when ``True`` the in-place native op is used. Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            _threshold, (input,), input, threshold, value, inplace=inplace
        )
    if inplace:
        return _VF.threshold_(input, threshold, value)
    return _VF.threshold(input, threshold, value)
# The implementation is named ``_threshold`` because one of its arguments is
# called ``threshold``; that argument would clobber the recursive reference
# to the function that ``__torch_function__`` support needs. The public name
# is bound here instead.
threshold = _threshold

# Bind the native in-place op under this module's namespace; ``_add_docstr``
# is handed the documentation text below.
threshold_ = _add_docstr(
    _VF.threshold_,
    r"""
threshold_(input, threshold, value) -> Tensor
In-place version of :func:`~threshold`.
""",
)
def relu(input: Tensor, inplace: bool = False) -> Tensor:  # noqa: D400,D402
    r"""relu(input, inplace=False) -> Tensor

    Applies the rectified linear unit function element-wise. See
    :class:`~torch.nn.ReLU` for more details.

    Args:
        input: tensor to transform.
        inplace: when ``True`` the in-place op ``torch.relu_`` is used.
            Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(relu, (input,), input, inplace=inplace)
    if inplace:
        return torch.relu_(input)
    return torch.relu(input)
# Bind the native in-place op under this module's namespace; ``_add_docstr``
# is handed the documentation text below.
relu_ = _add_docstr(
    torch.relu_,
    r"""
relu_(input) -> Tensor
In-place version of :func:`~relu`.
""",
)
def glu(input: Tensor, dim: int = -1) -> Tensor:  # noqa: D400,D402
    r"""
    glu(input, dim=-1) -> Tensor

    The gated linear unit. Computes:

    .. math ::
        \text{GLU}(a, b) = a \otimes \sigma(b)

    where `input` is split in half along `dim` to form `a` and `b`, :math:`\sigma`
    is the sigmoid function and :math:`\otimes` is the element-wise product between matrices.

    See `Language Modeling with Gated Convolutional Networks <https://arxiv.org/abs/1612.08083>`_.

    Args:
        input (Tensor): input tensor
        dim (int): dimension on which to split the input. Default: -1

    Raises:
        RuntimeError: if ``input`` is a zero-dimensional tensor.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(glu, (input,), input, dim=dim)
    ndim = input.dim()
    if ndim == 0:
        raise RuntimeError(
            "glu does not support scalars because halving size must be even"
        )
    gated = torch._C._nn.glu(input, dim)
    return gated
def hardtanh(
    input: Tensor,
    min_val: float = -1.0,
    max_val: float = 1.0,
    inplace: bool = False,
) -> Tensor:  # noqa: D400,D402
    r"""
    hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor

    Applies the HardTanh function element-wise. See :class:`~torch.nn.Hardtanh` for more
    details.

    Args:
        input: tensor to transform.
        min_val: lower bound passed to the native op. Default: -1.0
        max_val: upper bound passed to the native op. Default: 1.0
        inplace: when ``True`` the in-place native op is used. Default: ``False``

    Raises:
        ValueError: if ``min_val`` is greater than ``max_val``.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace
        )
    if min_val > max_val:
        raise ValueError("min_val cannot be greater than max_val")
    if inplace:
        return torch._C._nn.hardtanh_(input, min_val, max_val)
    return torch._C._nn.hardtanh(input, min_val, max_val)
# Bind the native in-place op under this module's namespace; ``_add_docstr``
# is handed the documentation text below.
hardtanh_ = _add_docstr(
    torch._C._nn.hardtanh_,
    r"""
hardtanh_(input, min_val=-1., max_val=1.) -> Tensor
In-place version of :func:`~hardtanh`.
""",
)
def relu6(input: Tensor, inplace: bool = False) -> Tensor:  # noqa: D400,D402
    r"""relu6(input, inplace=False) -> Tensor

    Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`.

    See :class:`~torch.nn.ReLU6` for more details.

    Args:
        input: tensor to transform.
        inplace: when ``True`` the in-place native op is used. Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(relu6, (input,), input, inplace=inplace)
    if inplace:
        return torch._C._nn.relu6_(input)
    return torch._C._nn.relu6(input)
def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor:
    r"""Apply the Exponential Linear Unit (ELU) function element-wise.

    See :class:`~torch.nn.ELU` for more details.

    Args:
        input: tensor to transform.
        alpha: forwarded to the native ELU op. Default: 1.0
        inplace: when ``True`` the in-place native op is used. Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace)
    if inplace:
        return torch._C._nn.elu_(input, alpha)
    return torch._C._nn.elu(input, alpha)
# Bind the native in-place op under this module's namespace; ``_add_docstr``
# is handed the documentation text below.
elu_ = _add_docstr(
    torch._C._nn.elu_,
    r"""
elu_(input, alpha=1.) -> Tensor
In-place version of :func:`~elu`.
""",
)
def selu(input: Tensor, inplace: bool = False) -> Tensor:  # noqa: D400,D402
    r"""selu(input, inplace=False) -> Tensor

    Applies element-wise,
    :math:`\text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))`,
    with :math:`\alpha=1.6732632423543772848170429916717` and
    :math:`scale=1.0507009873554804934193349852946`.

    See :class:`~torch.nn.SELU` for more details.

    Args:
        input: tensor to transform.
        inplace: when ``True`` the in-place op ``torch.selu_`` is used.
            Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(selu, (input,), input, inplace=inplace)
    if inplace:
        return torch.selu_(input)
    return torch.selu(input)
# Bind the native in-place op under this module's namespace; ``_add_docstr``
# is handed the documentation text below.
selu_ = _add_docstr(
    torch.selu_,
    r"""
selu_(input) -> Tensor
In-place version of :func:`~selu`.
""",
)
def celu(
    input: Tensor,
    alpha: float = 1.0,
    inplace: bool = False,
) -> Tensor:  # noqa: D400,D402
    r"""celu(input, alpha=1., inplace=False) -> Tensor

    Applies element-wise,
    :math:`\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))`.

    See :class:`~torch.nn.CELU` for more details.

    Args:
        input: tensor to transform.
        alpha: the :math:`\alpha` of the formula above. Default: 1.0
        inplace: when ``True`` the in-place op ``torch.celu_`` is used.
            Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            celu, (input,), input, alpha=alpha, inplace=inplace
        )
    if inplace:
        return torch.celu_(input, alpha)
    return torch.celu(input, alpha)
# Bind the native in-place op under this module's namespace; ``_add_docstr``
# is handed the documentation text below.
celu_ = _add_docstr(
    torch.celu_,
    r"""
celu_(input, alpha=1.) -> Tensor
In-place version of :func:`~celu`.
""",
)
def leaky_relu(
    input: Tensor,
    negative_slope: float = 0.01,
    inplace: bool = False,
) -> Tensor:  # noqa: D400,D402
    r"""
    leaky_relu(input, negative_slope=0.01, inplace=False) -> Tensor

    Applies element-wise,
    :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)`

    See :class:`~torch.nn.LeakyReLU` for more details.

    Args:
        input: tensor to transform.
        negative_slope: slope applied to negative inputs, as in the formula
            above. Default: 0.01
        inplace: when ``True`` the in-place native op is used. Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace
        )
    if inplace:
        return torch._C._nn.leaky_relu_(input, negative_slope)
    return torch._C._nn.leaky_relu(input, negative_slope)
# Bind the native in-place op under this module's namespace; ``_add_docstr``
# is handed the documentation text below.
leaky_relu_ = _add_docstr(
    torch._C._nn.leaky_relu_,
    r"""
leaky_relu_(input, negative_slope=0.01) -> Tensor
In-place version of :func:`~leaky_relu`.
""",
)
# Bind the native ``torch.prelu`` under this module's namespace;
# ``_add_docstr`` is handed the documentation text below.
prelu = _add_docstr(
    torch.prelu,
    r"""prelu(input, weight) -> Tensor
Applies element-wise the function
:math:`\text{PReLU}(x) = \max(0,x) + \text{weight} * \min(0,x)` where weight is a
learnable parameter.
.. note::
`weight` is expected to be a scalar or 1-D tensor. If `weight` is 1-D,
its size must match the number of input channels, determined by
`input.size(1)` when `input.dim() >= 2`, otherwise 1.
In the 1-D case, note that when `input` has dim > 2, `weight` can be expanded
to the shape of `input` in a way that is not possible using normal
:ref:`broadcasting semantics<broadcasting-semantics>`.
See :class:`~torch.nn.PReLU` for more details.
""",
)
def rrelu(
    input: Tensor,
    lower: float = 1.0 / 8,
    upper: float = 1.0 / 3,
    training: bool = False,
    inplace: bool = False,
) -> Tensor:  # noqa: D400,D402
    r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor

    Randomized leaky ReLU.

    See :class:`~torch.nn.RReLU` for more details.

    Args:
        input: tensor to transform.
        lower: forwarded to the native op. Default: ``1 / 8``
        upper: forwarded to the native op. Default: ``1 / 3``
        training: forwarded to the native op. Default: ``False``
        inplace: when ``True`` the in-place op ``torch.rrelu_`` is used.
            Default: ``False``
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            rrelu,
            (input,),
            input,
            lower=lower,
            upper=upper,
            training=training,
            inplace=inplace,
        )
    if inplace:
        return torch.rrelu_(input, lower, upper, training)
    return torch.rrelu(input, lower, upper, training)
# Bind the native in-place op under this module's namespace; ``_add_docstr``
# is handed the documentation text below.
rrelu_ = _add_docstr(
    torch.rrelu_,
    r"""
rrelu_(input, lower=1./8, upper=1./3, training=False) -> Tensor
In-place version of :func:`~rrelu`.
""",
)
# Bind the native ``log_sigmoid`` op as ``logsigmoid`` in this module;
# ``_add_docstr`` is handed the documentation text below.
logsigmoid = _add_docstr(
    torch._C._nn.log_sigmoid,
    r"""
logsigmoid(input) -> Tensor
Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)`
See :class:`~torch.nn.LogSigmoid` for more details.
""",
)
# Bind the native ``gelu`` op under this module's namespace; ``_add_docstr``
# is handed the documentation text below.
gelu = _add_docstr(
    torch._C._nn.gelu,
    r"""
gelu(input, approximate = 'none') -> Tensor
When the approximate argument is 'none', it applies element-wise the function
:math:`\text{GELU}(x) = x * \Phi(x)`
where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
When the approximate argument is 'tanh', Gelu is estimated with
.. math::
\text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))
See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_.
""",
)
# Bind the native ``torch.hardshrink`` under this module's namespace;
# ``_add_docstr`` is handed the documentation text below.
hardshrink = _add_docstr(
    torch.hardshrink,
    r"""
hardshrink(input, lambd=0.5) -> Tensor
Applies the hard shrinkage function element-wise
See :class:`~torch.nn.Hardshrink` for more details.
""",
)
def tanhshrink(input):  # noqa: D400,D402
    r"""tanhshrink(input) -> Tensor

    Applies element-wise, :math:`\text{Tanhshrink}(x) = x - \text{Tanh}(x)`

    See :class:`~torch.nn.Tanhshrink` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(tanhshrink, (input,), input)
    squashed = input.tanh()
    return input - squashed
def softsign(input):  # noqa: D400,D402
    r"""softsign(input) -> Tensor

    Applies element-wise, the function :math:`\text{SoftSign}(x) = \frac{x}{1 + |x|}`

    See :class:`~torch.nn.Softsign` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(softsign, (input,), input)
    denominator = input.abs() + 1
    return input / denominator
# Bind the native ``softplus`` op under this module's namespace;
# ``_add_docstr`` is handed the documentation text below.
softplus = _add_docstr(
    torch._C._nn.softplus,
    r"""
softplus(input, beta=1, threshold=20) -> Tensor
Applies element-wise, the function :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))`.
For numerical stability the implementation reverts to the linear function
when :math:`input \times \beta > threshold`.
See :class:`~torch.nn.Softplus` for more details.
""",
)
def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int:
    # Legacy fallback used when a softmax-family function is called without
    # ``dim``: always warns, then returns dimension 0 for inputs with 0, 1 or
    # 3 dimensions and dimension 1 for everything else.
    warnings.warn(
        f"Implicit dimension choice for {name} has been deprecated. "
        "Change the call to include dim=X as an argument.",
        stacklevel=stacklevel,
    )
    return 0 if (ndim == 0 or ndim == 1 or ndim == 3) else 1
def softmin(
    input: Tensor,
    dim: Optional[int] = None,
    _stacklevel: int = 3,
    dtype: Optional[DType] = None,
) -> Tensor:
    r"""Apply a softmin function.

    Note that :math:`\text{Softmin}(x) = \text{Softmax}(-x)`. See softmax definition for mathematical formula.

    See :class:`~torch.nn.Softmin` for more details.

    Args:
        input (Tensor): input
        dim (int): A dimension along which softmin will be computed (so every slice
            along dim will sum to 1). When omitted, a deprecation warning is
            issued and a dimension is chosen from the input's rank.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            If specified, the input tensor is cast to :attr:`dtype` before the operation
            is performed. This is useful for preventing data type overflows. Default: None.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype
        )
    if dim is None:
        dim = _get_softmax_dim("softmin", input.dim(), _stacklevel)
    negated = -input
    if dtype is not None:
        return negated.softmax(dim, dtype=dtype)
    return negated.softmax(dim)
def softmax(
    input: Tensor,
    dim: Optional[int] = None,
    _stacklevel: int = 3,
    dtype: Optional[DType] = None,
) -> Tensor:
    r"""Apply a softmax function.

    Softmax is defined as:

    :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`

    It is applied to all slices along dim, and will re-scale them so that the elements
    lie in the range `[0, 1]` and sum to 1.

    See :class:`~torch.nn.Softmax` for more details.

    Args:
        input (Tensor): input
        dim (int): A dimension along which softmax will be computed. When
            omitted, a deprecation warning is issued and a dimension is
            chosen from the input's rank.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            If specified, the input tensor is cast to :attr:`dtype` before the operation
            is performed. This is useful for preventing data type overflows. Default: None.

    .. note::
        This function doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use log_softmax instead (it's faster and has better numerical properties).

    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype
        )
    if dim is None:
        dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)
    if dtype is not None:
        return input.softmax(dim, dtype=dtype)
    return input.softmax(dim)
def gumbel_softmax(
    logits: Tensor,
    tau: float = 1,
    hard: bool = False,
    eps: float = 1e-10,
    dim: int = -1,
) -> Tensor:
    r"""
    Sample from the Gumbel-Softmax distribution (`Link 1`_  `Link 2`_) and optionally discretize.

    Args:
      logits: `[..., num_features]` unnormalized log probabilities
      tau: non-negative scalar temperature
      hard: if ``True``, the returned samples will be discretized as one-hot vectors,
            but will be differentiated as if it is the soft sample in autograd
      eps: deprecated and without effect; a warning is issued when a value
            other than the default is passed
      dim (int): A dimension along which softmax will be computed. Default: -1.

    Returns:
      Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
      If ``hard=True``, the returned samples will be one-hot, otherwise they will
      be probability distributions that sum to 1 across `dim`.

    .. note::
      This function is here for legacy reasons, may be removed from nn.Functional in the future.

    .. note::
      The main trick for `hard` is to do  `y_hard - y_soft.detach() + y_soft`

      It achieves two things:
      - makes the output value exactly one-hot
      (since we add then subtract y_soft value)
      - makes the gradient equal to y_soft gradient
      (since we strip all other gradients)

    Examples::
        >>> logits = torch.randn(20, 32)
        >>> # Sample soft categorical using reparametrization trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=False)
        >>> # Sample hard categorical using "Straight-through" trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=True)

    .. _Link 1:
        https://arxiv.org/abs/1611.00712
    .. _Link 2:
        https://arxiv.org/abs/1611.01144
    """
    if has_torch_function_unary(logits):
        return handle_torch_function(
            gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim
        )
    if eps != 1e-10:
        warnings.warn("`eps` parameter is deprecated and has no effect.")

    # ~Gumbel(0,1): negated log of exponential samples.
    exponential_draws = torch.empty_like(
        logits, memory_format=torch.legacy_contiguous_format
    ).exponential_()
    gumbel_noise = -exponential_draws.log()
    # ~Gumbel(logits,tau)
    perturbed = (logits + gumbel_noise) / tau
    y_soft = perturbed.softmax(dim)

    if not hard:
        # Reparametrization trick.
        return y_soft

    # Straight through.
    _max_values, index = y_soft.max(dim, keepdim=True)
    y_hard = torch.zeros_like(
        logits, memory_format=torch.legacy_contiguous_format
    ).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft
def log_softmax(
    input: Tensor,
    dim: Optional[int] = None,
    _stacklevel: int = 3,
    dtype: Optional[DType] = None,
) -> Tensor:
    r"""Apply a softmax followed by a logarithm.

    While mathematically equivalent to log(softmax(x)), doing these two
    operations separately is slower and numerically unstable. This function
    uses an alternative formulation to compute the output and gradient correctly.

    See :class:`~torch.nn.LogSoftmax` for more details.

    Args:
        input (Tensor): input
        dim (int): A dimension along which log_softmax will be computed. When
            omitted, a deprecation warning is issued and a dimension is
            chosen from the input's rank.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            If specified, the input tensor is cast to :attr:`dtype` before the operation
            is performed. This is useful for preventing data type overflows. Default: None.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype
        )
    if dim is None:
        dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel)
    if dtype is not None:
        return input.log_softmax(dim, dtype=dtype)
    return input.log_softmax(dim)
# Bind the native ``softshrink`` op under this module's namespace;
# ``_add_docstr`` is handed the documentation text below.
softshrink = _add_docstr(
    torch._C._nn.softshrink,
    r"""
softshrink(input, lambd=0.5) -> Tensor
Applies the soft shrinkage function elementwise
See :class:`~torch.nn.Softshrink` for more details.
""",
)
def tanh(input):  # noqa: D400,D402
    r"""tanh(input) -> Tensor

    Applies element-wise,
    :math:`\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}`

    See :class:`~torch.nn.Tanh` for more details.
    """
    # Delegates to the input's own ``tanh`` method.
    result = input.tanh()
    return result
def sigmoid(input):  # noqa: D400,D402
    r"""sigmoid(input) -> Tensor

    Apply the logistic sigmoid to every element of :attr:`input`:

    :math:`\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}`

    The work is delegated to the ``sigmoid`` method of :attr:`input`.

    See :class:`~torch.nn.Sigmoid` for more details.
    """
    result = input.sigmoid()
    return result
def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor:
    r"""Apply the piecewise-linear Hardsigmoid function to every element.

    .. math::
        \text{Hardsigmoid}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            1 & \text{if~} x \ge +3, \\
            x / 6 + 1 / 2 & \text{otherwise}
        \end{cases}

    Args:
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``

    See :class:`~torch.nn.Hardsigmoid` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace)

    # Out-of-place is the default; the in-place variant is used only on request.
    if not inplace:
        return torch._C._nn.hardsigmoid(input)
    return torch._C._nn.hardsigmoid_(input)
# Bind ``linear`` to the result of ``_add_docstr`` applied to
# ``torch._C._nn.linear`` with the user-facing docstring below.
# The ``{sparse_beta_warning}`` placeholder in the text is filled in by
# ``.format(**sparse_support_notes)`` before the docstring is attached, so any
# other literal braces added to this text would have to be doubled.
# NOTE(review): presumably ``_add_docstr`` returns the same callable with its
# ``__doc__`` set -- confirm against the helper's definition.
linear = _add_docstr(
    torch._C._nn.linear,
    r"""
linear(input, weight, bias=None) -> Tensor
Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.
This operation supports 2-D :attr:`weight` with :ref:`sparse layout<sparse-docs>`
{sparse_beta_warning}
This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.
Shape:
- Input: :math:`(*, in\_features)` where `*` means any number of
additional dimensions, including none
- Weight: :math:`(out\_features, in\_features)` or :math:`(in\_features)`
- Bias: :math:`(out\_features)` or :math:`()`
- Output: :math:`(*, out\_features)` or :math:`(*)`, based on the shape of the weight
""".format(**sparse_support_notes),
)
# Bind ``bilinear`` to the result of ``_add_docstr`` applied to
# ``torch.bilinear`` with the user-facing docstring below. Unlike ``linear``,
# this text is NOT passed through ``str.format``, so the LaTeX braces in it are
# literal.
# NOTE(review): presumably ``_add_docstr`` returns the same callable with its
# ``__doc__`` set -- confirm against the helper's definition.
bilinear = _add_docstr(
    torch.bilinear,
    r"""
bilinear(input1, input2, weight, bias=None) -> Tensor
Applies a bilinear transformation to the incoming data:
:math:`y = x_1^T A x_2 + b`
Shape:
- input1: :math:`(N, *, H_{in1})` where :math:`H_{in1}=\text{in1\_features}`
and :math:`*` means any number of additional dimensions.
All but the last dimension of the inputs should be the same.
- input2: :math:`(N, *, H_{in2})` where :math:`H_{in2}=\text{in2\_features}`
- weight: :math:`(\text{out\_features}, \text{in1\_features},
\text{in2\_features})`
- bias: :math:`(\text{out\_features})`
- output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}`
and all but the last dimension are the same shape as the input.
""",
)
def silu(input: Tensor, inplace: bool = False) -> Tensor:
    r"""Apply the Sigmoid Linear Unit (SiLU), also called swish, to every element.

    .. math::
        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

    .. note::
        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
        where the SiLU was experimented with later.

    Args:
        inplace: when ``True`` the in-place kernel ``silu_`` is called on
            :attr:`input`; otherwise the out-of-place ``silu`` is used.
            Default: ``False``

    See :class:`~torch.nn.SiLU` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(silu, (input,), input, inplace=inplace)

    if inplace:
        result = torch._C._nn.silu_(input)
    else:
        result = torch._C._nn.silu(input)
    return result
def mish(input: Tensor, inplace: bool = False) -> Tensor:
    r"""Apply the Mish activation to every element.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function.

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

    .. note::
        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_

    Args:
        inplace: when ``True`` the in-place kernel ``mish_`` is called on
            :attr:`input`; otherwise the out-of-place ``mish`` is used.
            Default: ``False``

    See :class:`~torch.nn.Mish` for more details.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(mish, (input,), input, inplace=inplace)

    if not inplace:
        return torch._C._nn.mish(input)
    return torch._C._nn.mish_(input)
def hardswish(input: Tensor, inplace: bool = False) -> Tensor:
    r"""Apply the hardswish activation to every element.

    Follows implementation as described in the paper:
    `Searching for MobileNetV3`_.

    .. math::
        \text{Hardswish}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            x & \text{if~} x \ge +3, \\
            x \cdot (x + 3) /6 & \text{otherwise}
        \end{cases}

    Args:
        inplace: when ``True`` the in-place kernel ``hardswish_`` is called on
            :attr:`input`; otherwise the out-of-place ``hardswish`` is used.
            Default: ``False``

    See :class:`~torch.nn.Hardswish` for more details.

    .. _`Searching for MobileNetV3`:
        https://arxiv.org/abs/1905.02244
    """
    if has_torch_function_unary(input):
        return handle_torch_function(hardswish, (input,), input, inplace=inplace)

    if inplace:
        result = torch._C._nn.hardswish_(input)
    else:
        result = torch._C._nn.hardswish(input)
    return result
def _no_grad_embedding_renorm_(
    weight: Tensor,
    input: Tensor,
    max_norm: float,
    norm_type: float,
) -> tuple[Tensor, Tensor]:
    """Run ``torch.embedding_renorm_`` on a detached alias of :attr:`weight`.

    The renormalisation is applied through ``weight.detach()``, i.e. outside of
    the autograd graph of :attr:`weight`.

    NOTE(review): the body has no ``return`` statement, so a plain Python call
    yields ``None`` even though the annotation says ``tuple[Tensor, Tensor]``;
    the annotation is left untouched here -- confirm what depends on it before
    changing it.
    """
    detached_weight = weight.detach()
    torch.embedding_renorm_(detached_weight, input, max_norm, norm_type)
def embedding(
    input: Tensor,
    weight: Tensor,
    padding_idx: Optional[int] = None,
    max_norm: Optional[float] = None,
    norm_type: float = 2.0,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
) -> Tensor:
    r"""Look up rows of an embedding matrix by index.

    This function is often used to retrieve word embeddings using indices:
    :attr:`input` holds the indices, :attr:`weight` is the embedding matrix,
    and the output contains the corresponding embedding vectors.

    See :class:`torch.nn.Embedding` for more details.

    .. note::
        Note that the analytical gradients of this function with respect to
        entries in :attr:`weight` at the row specified by :attr:`padding_idx`
        are expected to differ from the numerical ones.

    .. note::
        Note that :class:`torch.nn.Embedding` differs from this function in
        that it initializes the row of :attr:`weight` specified by
        :attr:`padding_idx` to all zeros on construction.

    Args:
        input (LongTensor): Tensor containing indices into the embedding matrix
        weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1,
            and number of columns equal to the embedding size
        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
            therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
            i.e. it remains as a fixed "pad". A negative value counts from the end of
            :attr:`weight`; a value outside ``[-num_embeddings, num_embeddings)`` fails an assertion.
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
            is renormalized to have norm :attr:`max_norm`.
            Note: this will modify :attr:`weight` in-place.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of
            the words in the mini-batch. Default ``False``.
        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under
            :class:`torch.nn.Embedding` for more details regarding sparse gradients.

    Shape:
        - Input: LongTensor of arbitrary shape containing the indices to extract
        - Weight: Embedding matrix of floating point type with shape `(V, embedding_dim)`,
          where V = maximum index + 1 and embedding_dim = the embedding size
        - Output: `(*, embedding_dim)`, where `*` is the input shape

    Examples::

        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
        >>> # an embedding matrix containing 10 tensors of size 3
        >>> embedding_matrix = torch.rand(10, 3)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> F.embedding(input, embedding_matrix)
        tensor([[[ 0.8490,  0.9625,  0.6753],
                 [ 0.9666,  0.7761,  0.6108],
                 [ 0.6246,  0.9751,  0.3618],
                 [ 0.4161,  0.2419,  0.7383]],

                [[ 0.6246,  0.9751,  0.3618],
                 [ 0.0237,  0.7794,  0.0528],
                 [ 0.9666,  0.7761,  0.6108],
                 [ 0.3385,  0.8612,  0.1867]]])

        >>> # example with padding_idx
        >>> weights = torch.rand(10, 3)
        >>> weights[0, :].zero_()
        >>> embedding_matrix = weights
        >>> input = torch.tensor([[0, 2, 0, 5]])
        >>> F.embedding(input, embedding_matrix, padding_idx=0)
        tensor([[[ 0.0000,  0.0000,  0.0000],
                 [ 0.5609,  0.5384,  0.8720],
                 [ 0.0000,  0.0000,  0.0000],
                 [ 0.6262,  0.2438,  0.7471]]])
    """
    if has_torch_function_variadic(input, weight):
        return handle_torch_function(
            embedding,
            (input, weight),
            input,
            weight,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
        )

    # Normalise ``padding_idx`` to the integer handed to ``torch.embedding``:
    # -1 stands for "no padding row", negative user values wrap around.
    if padding_idx is None:
        padding_idx = -1
    elif padding_idx > 0:
        assert padding_idx < weight.size(0), (
            "Padding_idx must be within num_embeddings"
        )
    elif padding_idx < 0:
        assert padding_idx >= -weight.size(0), (
            "Padding_idx must be within num_embeddings"
        )
        padding_idx = weight.size(0) + padding_idx

    if max_norm is not None:
        # Note [embedding_renorm contiguous]
        # `embedding_renorm_` makes its input contiguous anyway; doing it up
        # front lets the `embedding` call below profit from the better
        # locality as well.
        input = input.contiguous()
        # Note [embedding_renorm set_grad_enabled]
        # XXX: equivalent to
        # with torch.no_grad():
        #   torch.embedding_renorm_
        # remove once script supports set_grad_enabled
        _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)

    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
def embedding_bag(
    input: Tensor,
    weight: Tensor,
    offsets: Optional[Tensor] = None,
    max_norm: Optional[float] = None,
    norm_type: float = 2,
    scale_grad_by_freq: bool = False,
    mode: str = "mean",
    sparse: bool = False,
    per_sample_weights: Optional[Tensor] = None,
    include_last_offset: bool = False,
    padding_idx: Optional[int] = None,
) -> Tensor:
    r"""Compute sums, means or maxes of `bags` of embeddings.

    Calculation is done without instantiating the intermediate embeddings.
    See :class:`torch.nn.EmbeddingBag` for more details.

    Note:
        {backward_reproducibility_note}

    Args:
        input (LongTensor): Tensor containing bags of indices into the embedding matrix
        weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1,
            and number of columns equal to the embedding size
        offsets (LongTensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines
            the starting index position of each bag (sequence) in :attr:`input`.
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
            is renormalized to have norm :attr:`max_norm`.
            Note: this will modify :attr:`weight` in-place.
        norm_type (float, optional): The ``p`` in the ``p``-norm to compute for the :attr:`max_norm` option.
            Default ``2``.
        scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of
            the words in the mini-batch. Default ``False``.
            Note: this option is not supported when ``mode="max"``.
        mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
            Default: ``"mean"``
        sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under
            :class:`torch.nn.Embedding` for more details regarding sparse gradients.
            Note: this option is not supported when ``mode="max"``.
        per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
            to indicate all weights should be taken to be 1. If specified, :attr:`per_sample_weights`
            must have exactly the same shape as input and is treated as having the same
            :attr:`offsets`, if those are not None. Only supported for ``mode="sum"``.
        include_last_offset (bool, optional): if ``True``, the size of offsets is equal to the number of bags + 1.
            The last element is the size of the input, or the ending index position of the last bag (sequence).
        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the
            gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated
            during training, i.e. it remains as a fixed "pad". Note that the embedding
            vector at :attr:`padding_idx` is excluded from the reduction.

    Shape:
        - :attr:`input` (LongTensor) and :attr:`offsets` (LongTensor, optional)

          - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences)
            each of fixed length ``N``, and this will return ``B`` values aggregated in a way
            depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case.

          - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of
            multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing
            the starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets`
            of shape `(B)`, :attr:`input` will be viewed as having ``B`` bags.
            Empty bags (i.e., having 0-length) will have returned vectors filled by zeros.

        - :attr:`weight` (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`

        - :attr:`per_sample_weights` (Tensor, optional). Has the same shape as :attr:`input`.

        - :attr:`output`: aggregated embedding values of shape `(B, embedding_dim)`

    Raises:
        ValueError: if :attr:`per_sample_weights` and :attr:`input` differ in shape, if
            :attr:`weight` is not 2D, if :attr:`input` is neither 1D nor 2D, if
            :attr:`offsets` is inconsistent with the dimensionality of :attr:`input`, if
            :attr:`mode` is not one of the three supported strings, or if ``mode="max"``
            is combined with :attr:`scale_grad_by_freq` or :attr:`sparse`.
        NotImplementedError: if :attr:`per_sample_weights` is given with a mode other
            than ``"sum"``.

    Examples::

        >>> # an Embedding module containing 10 tensors of size 3
        >>> embedding_matrix = torch.rand(10, 3)
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
        >>> offsets = torch.tensor([0, 4])
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> F.embedding_bag(input, embedding_matrix, offsets)
        tensor([[ 0.3397,  0.3552,  0.5545],
                [ 0.5893,  0.4386,  0.5882]])

        >>> # example with padding_idx
        >>> embedding_matrix = torch.rand(10, 3)
        >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9])
        >>> offsets = torch.tensor([0, 4])
        >>> F.embedding_bag(input, embedding_matrix, offsets, padding_idx=2, mode='sum')
        tensor([[ 0.0000,  0.0000,  0.0000],
                [-0.7082,  3.2145, -2.6251]])
    """
    if has_torch_function_variadic(input, weight, offsets, per_sample_weights):
        return handle_torch_function(
            embedding_bag,
            (input, weight, offsets, per_sample_weights),
            input,
            weight,
            offsets=offsets,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            mode=mode,
            sparse=sparse,
            per_sample_weights=per_sample_weights,
            include_last_offset=include_last_offset,
            padding_idx=padding_idx,
        )
    # Check for backward compatibility.
    # Used to be embedding_bag(weight, input, ...)
    # Now is     embedding_bag(input, weight, ...)
    # A long-typed `weight` together with a floating-point `input` is taken as
    # the old argument order: warn, then swap the two.
    if weight.dtype == torch.long and input.is_floating_point():
        warnings.warn(
            "Argument order of nn.functional.embedding_bag was changed. "
            "Usage `embedding_bag(weight, input, ...)` is deprecated, "
            "and should now be `embedding_bag(input, weight, ...)`."
        )
        weight, input = input, weight

    # This shape check runs before `input` is flattened below.
    if per_sample_weights is not None and input.size() != per_sample_weights.size():
        raise ValueError(
            f"embedding_bag: If per_sample_weights ({per_sample_weights.shape}) is not None, "
            f"then it must have the same shape as the input ({input.shape})"
        )

    if not weight.dim() == 2:
        raise ValueError(
            f"weight has to be a 2D Tensor, but got Tensor of dimension {weight.dim()}"
        )

    # Bring `input` / `offsets` / `per_sample_weights` into the flat 1D layout
    # handed to `torch.embedding_bag`. Three accepted layouts:
    #   * nested 2D input (eager mode only): offsets come from the nested tensor
    #   * dense 2D input: evenly spaced offsets are generated
    #   * 1D input: the caller must supply 1D offsets
    if not torch.jit.is_scripting() and input.dim() == 2 and input.is_nested:
        # The nested tensor's own offsets are used, and include_last_offset is
        # forced on for them.
        include_last_offset = True
        offsets = input.offsets()
        input = input.values().reshape(-1)
        if per_sample_weights is not None:
            if not per_sample_weights.is_nested:
                raise ValueError(
                    "If input is nested, then per_sample_weights must be nested if specified"
                )
            per_sample_weights = per_sample_weights.values().reshape(-1)
    elif input.dim() == 2:
        if offsets is not None:
            type_str = "<unknown>"
            # TODO: Remove this once script supports type() calls
            if not torch.jit.is_scripting():
                type_str = str(type(offsets))
            raise ValueError(
                "if input is 2D, then offsets has to be None"
                ", as input is treated is a mini-batch of"
                " fixed length sequences. However, found "
                f"offsets of type {type_str}"
            )
        # One bag per row: bag starts are 0, N, 2N, ... with N = input.size(1).
        offsets = torch.arange(
            0, input.numel(), input.size(1), dtype=input.dtype, device=input.device
        )

        input = input.reshape(-1)
        if per_sample_weights is not None:
            per_sample_weights = per_sample_weights.reshape(-1)
    elif input.dim() == 1:
        if offsets is None:
            raise ValueError("offsets has to be a 1D Tensor but got None")
        if offsets.dim() != 1:
            raise ValueError("offsets has to be a 1D Tensor")
    else:
        raise ValueError(
            f"input has to be 1D or 2D Tensor, but got Tensor of dimension {input.dim()}"
        )
    # Translate the mode string into the integer code passed to
    # `torch.embedding_bag` (sum -> 0, mean -> 1, max -> 2).
    if mode == "sum":
        mode_enum = 0
    elif mode == "mean":
        mode_enum = 1
    elif mode == "max":
        mode_enum = 2

        if scale_grad_by_freq:
            raise ValueError(
                "max mode does not support scaling the gradient by the frequency"
            )

        if sparse:
            raise ValueError("max mode does not support sparse weights")

    else:
        raise ValueError("mode has to be one of sum, mean or max")

    if max_norm is not None:
        # XXX: equivalent to
        # with torch.no_grad():
        #   torch.embedding_renorm_
        # remove once script supports set_grad_enabled
        _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)

    # NOTE: this check comes after the optional renormalisation above, so
    # `weight` may already have been modified in-place when it raises.
    if per_sample_weights is not None and mode != "sum":
        raise NotImplementedError(
            "embedding_bag: per_sample_weights was not None. "
            "per_sample_weights is only supported for mode='sum' "
            f"(got mode='{mode}'). Please open a feature request on GitHub."
        )

    # `torch.embedding_bag` returns four values; only the first one is
    # returned to the caller.
    ret, _, _, _ = torch.embedding_bag(
        weight,
        input,
        offsets,
        scale_grad_by_freq,
        mode_enum,
        sparse,
        per_sample_weights,
        include_last_offset,
        padding_idx,
    )
    return ret


# Fill the `{backward_reproducibility_note}` placeholder of the docstring.
# The guard skips this when `__doc__` is None or empty.
if embedding_bag.__doc__:
    embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes)
def _verify_batch_size(size: list[int]) -> None:
    """Reject inputs that offer only a single value per channel.

    Multiplies ``size[0]`` with every entry from index 2 onwards (index 1 is
    skipped) and raises if that product is exactly 1.

    Raises:
        ValueError: if the product described above equals 1.
    """
    # XXX: JIT script does not support the reduce from functools, and mul op is a
    # builtin, which cannot be used as a value to a func yet, so this size
    # check is spelled out as a plain loop.
    #
    # TODO: make use of reduce like below when JIT is ready with the missing features:
    # from operator import mul
    # from functools import reduce
    #
    #   if reduce(mul, size[2:], size[0]) == 1
    values_per_channel = size[0]
    for axis in range(2, len(size)):
        values_per_channel *= size[axis]
    if values_per_channel == 1:
        raise ValueError(
            f"Expected more than 1 value per channel when training, got input size {size}"
        )
def batch_norm(
    input: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    training: bool = False,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> Tensor:
    r"""Apply Batch Normalization for each channel across a batch of data.

    A thin wrapper over ``torch.batch_norm`` that, in training mode, first
    checks that the input offers more than one value per channel.

    See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
    :class:`~torch.nn.BatchNorm3d` for details.

    Raises:
        ValueError: if :attr:`training` is ``True`` and :attr:`input` has only
            one value per channel.
    """
    if has_torch_function_variadic(input, running_mean, running_var, weight, bias):
        return handle_torch_function(
            batch_norm,
            (input, running_mean, running_var, weight, bias),
            input,
            running_mean,
            running_var,
            weight=weight,
            bias=bias,
            training=training,
            momentum=momentum,
            eps=eps,
        )

    # The size check is only performed in training mode.
    if training:
        _verify_batch_size(input.size())

    # Argument order of the underlying op differs from this wrapper:
    # weight/bias come before the running statistics.
    normalized = torch.batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        momentum,
        eps,
        torch.backends.cudnn.enabled,
    )
    return normalized
def _verify_spatial_size(size: list[int]) -> None:
    """Reject inputs whose spatial extent is a single element.

    Every entry of :attr:`size` from index 2 onwards is multiplied together
    (the empty product is 1); instance norm needs that product to differ from 1.

    Raises:
        ValueError: if the product described above equals 1.
    """
    spatial_elements = 1
    for offset in range(len(size) - 2):
        spatial_elements *= size[offset + 2]
    if spatial_elements == 1:
        raise ValueError(
            f"Expected more than 1 spatial element when training, got input size {size}"
        )
def instance_norm(
    input: Tensor,
    running_mean: Optional[Tensor] = None,
    running_var: Optional[Tensor] = None,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    use_input_stats: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> Tensor:
    r"""Apply Instance Normalization independently for each channel in every data sample within a batch.

    A thin wrapper over ``torch.instance_norm`` that, when statistics are taken
    from the input, first checks that there is more than one spatial element.

    See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`,
    :class:`~torch.nn.InstanceNorm3d` for details.

    Raises:
        ValueError: if :attr:`use_input_stats` is ``True`` and :attr:`input`
            has a single spatial element.
    """
    if has_torch_function_variadic(input, running_mean, running_var, weight, bias):
        return handle_torch_function(
            instance_norm,
            (input, running_mean, running_var, weight, bias),
            input,
            running_mean=running_mean,
            running_var=running_var,
            weight=weight,
            bias=bias,
            use_input_stats=use_input_stats,
            momentum=momentum,
            eps=eps,
        )

    # The spatial-size check only applies when statistics come from the input.
    if use_input_stats:
        _verify_spatial_size(input.size())

    # Argument order of the underlying op differs from this wrapper:
    # weight/bias come before the running statistics.
    normalized = torch.instance_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        use_input_stats,
        momentum,
        eps,
        torch.backends.cudnn.enabled,
    )
    return normalized
def layer_norm(
    input: Tensor,
    normalized_shape: list[int],
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
) -> Tensor:
    r"""Apply Layer Normalization for last certain number of dimensions.

    A thin wrapper that forwards its arguments, plus the current
    ``torch.backends.cudnn.enabled`` flag, to ``torch.layer_norm``.

    See :class:`~torch.nn.LayerNorm` for details.
    """
    if has_torch_function_variadic(input, weight, bias):
        return handle_torch_function(
            layer_norm,
            (input, weight, bias),
            input,
            normalized_shape,
            weight=weight,
            bias=bias,
            eps=eps,
        )

    normalized = torch.layer_norm(
        input,
        normalized_shape,
        weight,
        bias,
        eps,
        torch.backends.cudnn.enabled,
    )
    return normalized
- def rms_norm(
- input: Tensor,
- normalized_shape: list[int],
- weight: Optional[Tensor] = None,
- eps: Optional[float] = None,
- ) -> Tensor:
- r"""Apply Root Mean Square Layer Normalization.
- See :class:`~torch.nn.RMSNorm` for details.
- """
- if has_torch_function_variadic(input, weight):
- return handle_torch_function(
- rms_norm, (input, weight), input, normalized_shape, weight=weight, eps=eps
- )
- return torch.rms_norm(input, normalized_shape, weight, eps)
def group_norm(
    input: Tensor,
    num_groups: int,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
) -> Tensor:
    r"""Apply Group Normalization for last certain number of dimensions.

    Validates the input and then forwards to ``torch.group_norm``.

    See :class:`~torch.nn.GroupNorm` for details.

    Raises:
        RuntimeError: if :attr:`input` has fewer than 2 dimensions.
        ValueError: if the size check performed by ``_verify_batch_size`` on
            the grouped layout fails.
    """
    if has_torch_function_variadic(input, weight, bias):
        return handle_torch_function(
            group_norm,
            (
                input,
                weight,
                bias,
            ),
            input,
            num_groups,
            weight=weight,
            bias=bias,
            eps=eps,
        )

    if input.dim() < 2:
        raise RuntimeError(
            f"Expected at least 2 dimensions for input tensor but received {input.dim()}"
        )

    # Size list handed to the shared check: the first two extents are replaced
    # by (N * C // num_groups, num_groups), the remaining ones are kept.
    leading = [input.size(0) * input.size(1) // num_groups, num_groups]
    trailing = list(input.size()[2:])
    _verify_batch_size(leading + trailing)

    normalized = torch.group_norm(
        input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled
    )
    return normalized
def local_response_norm(
    input: Tensor,
    size: int,
    alpha: float = 1e-4,
    beta: float = 0.75,
    k: float = 1.0,
) -> Tensor:
    r"""Apply local response normalization over an input signal.

    The input signal is composed of several input planes, where channels occupy the second dimension.
    Normalization is applied across channels: each element is divided by
    ``(k + alpha * m) ** beta``, where ``m`` is the average of the squared input
    over a window of :attr:`size` neighbouring channels (zero-padded at the edges).

    See :class:`~torch.nn.LocalResponseNorm` for details.

    Raises:
        ValueError: if :attr:`input` has fewer than 3 dimensions.
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k
        )

    ndim = input.dim()
    if ndim < 3:
        raise ValueError(
            f"Expected 3D or higher dimensionality input (got {ndim} dimensions)"
        )

    # Empty input: hand the very same tensor back.
    if input.numel() == 0:
        return input

    # Zero padding placed before / after the channel axis for the window.
    pad_front = size // 2
    pad_back = (size - 1) // 2

    squared = input.mul(input)
    if ndim == 3:
        # (N, C, L) -> (N, 1, C, L); pool over the channel axis only.
        squared = squared.unsqueeze(1)
        squared = pad(squared, (0, 0, pad_front, pad_back))
        window_mean = avg_pool2d(squared, (size, 1), stride=1).squeeze(1)
    else:
        # Fold everything after the third axis into one trailing axis,
        # pool over channels, then restore the original shape.
        shape = input.size()
        squared = squared.view(shape[0], 1, shape[1], shape[2], -1)
        squared = pad(squared, (0, 0, 0, 0, pad_front, pad_back))
        window_mean = avg_pool3d(squared, (size, 1, 1), stride=1).squeeze(1)
        window_mean = window_mean.view(shape)

    denominator = window_mean.mul(alpha).add(k).pow(beta)
    return input / denominator
- # loss
def ctc_loss(
    log_probs: Tensor,
    targets: Tensor,
    input_lengths: Tensor,
    target_lengths: Tensor,
    blank: int = 0,
    reduction: str = "mean",
    zero_infinity: bool = False,
) -> Tensor:
    r"""Compute the Connectionist Temporal Classification loss.

    The reduction string is translated to its enum value and everything is
    forwarded to ``torch.ctc_loss``.

    See :class:`~torch.nn.CTCLoss` for details.

    Note:
        {cudnn_reproducibility_note}

    Note:
        {backward_reproducibility_note}

    Args:
        log_probs: :math:`(T, N, C)` or :math:`(T, C)` where `C = number of characters in alphabet including blank`,
            `T = input length`, and `N = batch size`.
            The logarithmized probabilities of the outputs
            (e.g. obtained with :func:`torch.nn.functional.log_softmax`).
        targets: :math:`(N, S)` or `(sum(target_lengths))`.
            May be an empty tensor if all entries in `target_lengths` are zero.
            In the second form, the targets are assumed to be concatenated.
        input_lengths: :math:`(N)` or :math:`()`.
            Lengths of the inputs (must each be :math:`\leq T`)
        target_lengths: :math:`(N)` or :math:`()`.
            Lengths of the targets
        blank (int, optional):
            Blank label. Default :math:`0`.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the output losses will be divided by the target lengths and
            then the mean over the batch is taken, ``'sum'``: the output will be
            summed. Default: ``'mean'``
        zero_infinity (bool, optional):
            Whether to zero infinite losses and the associated gradients.
            Default: ``False``
            Infinite losses mainly occur when the inputs are too short
            to be aligned to the targets.

    Example::

        >>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
        >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
        >>> input_lengths = torch.full((16,), 50, dtype=torch.long)
        >>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
        >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
        >>> loss.backward()
    """
    if has_torch_function_variadic(log_probs, targets, input_lengths, target_lengths):
        return handle_torch_function(
            ctc_loss,
            (log_probs, targets, input_lengths, target_lengths),
            log_probs,
            targets,
            input_lengths,
            target_lengths,
            blank=blank,
            reduction=reduction,
            zero_infinity=zero_infinity,
        )

    # The underlying op takes the reduction as an enum value, not a string.
    reduction_enum = _Reduction.get_enum(reduction)
    return torch.ctc_loss(
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        blank,
        reduction_enum,
        zero_infinity,
    )


# Fill the two reproducibility-note placeholders of the docstring; the guard
# skips this when ``__doc__`` is None or empty.
if ctc_loss.__doc__:
    ctc_loss.__doc__ = ctc_loss.__doc__.format(**reproducibility_notes)
def nll_loss(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    ignore_index: int = -100,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:
    r"""Compute the negative log likelihood loss.

    See :class:`~torch.nn.NLLLoss` for details.

    Args:
        input: :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`
            in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K \geq 1`
            in the case of K-dimensional loss. `input` is expected to be log-probabilities.
        target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`,
            or :math:`(N, d_1, d_2, ..., d_K)` where :math:`K \geq 1` for
            K-dimensional loss.
        weight (Tensor, optional): A manual rescaling weight given to each
            class. If given, has to be a Tensor of size `C`
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        ignore_index (int, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. When :attr:`size_average` is
            ``True``, the loss is averaged over non-ignored targets. Default: -100
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``

    Example::

        >>> # input is of size N x C = 3 x 5
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> # each element in target has to have 0 <= value < C
        >>> target = torch.tensor([1, 0, 4])
        >>> output = F.nll_loss(F.log_softmax(input, dim=1), target)
        >>> output.backward()
    """
    if has_torch_function_variadic(input, target, weight):
        return handle_torch_function(
            nll_loss,
            (input, target, weight),
            input,
            target,
            weight=weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction=reduction,
        )

    # Either legacy flag, when given, takes precedence over ``reduction``.
    legacy_flags_given = size_average is not None or reduce is not None
    if legacy_flags_given:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    reduction_enum = _Reduction.get_enum(reduction)
    return torch._C._nn.nll_loss_nd(input, target, weight, reduction_enum, ignore_index)
def poisson_nll_loss(
    input: Tensor,
    target: Tensor,
    log_input: bool = True,
    full: bool = False,
    size_average: Optional[bool] = None,
    eps: float = 1e-8,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:
    r"""Compute the Poisson negative log likelihood loss.

    See :class:`~torch.nn.PoissonNLLLoss` for details.

    Args:
        input: Expectation of underlying Poisson distribution.
        target: Random sample :math:`target \sim \text{Poisson}(input)`.
        log_input: If ``True`` the loss is computed as
            :math:`\exp(\text{input}) - \text{target} * \text{input}`, if ``False`` then loss is
            :math:`\text{input} - \text{target} * \log(\text{input}+\text{eps})`. Default: ``True``
        full: Whether to compute full loss, i. e. to add the Stirling
            approximation term
            :math:`\text{target} * \log(\text{target}) - \text{target} + 0.5 * \log(2 * \pi * \text{target})`.
            Default: ``False``
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when
            :attr:`log_input`\ =\ ``False``. Default: 1e-8
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``

    Raises:
        ValueError: if the effective reduction is not one of ``'none'``,
            ``'mean'`` or ``'sum'``.
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            poisson_nll_loss,
            (input, target),
            input,
            target,
            log_input=log_input,
            full=full,
            size_average=size_average,
            eps=eps,
            reduce=reduce,
            reduction=reduction,
        )

    # Either legacy flag, when given, takes precedence over ``reduction``.
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    # Validate up front so the caller gets this function's own error message.
    if reduction != "none" and reduction != "mean" and reduction != "sum":
        raise ValueError(reduction + " is not a valid value for reduction")

    return torch.poisson_nll_loss(
        input, target, log_input, full, eps, _Reduction.get_enum(reduction)
    )
def gaussian_nll_loss(
    input: Tensor,
    target: Tensor,
    var: Union[Tensor, float],
    full: bool = False,
    eps: float = 1e-6,
    reduction: str = "mean",
) -> Tensor:
    r"""Compute the Gaussian negative log likelihood loss.

    See :class:`~torch.nn.GaussianNLLLoss` for details.

    Args:
        input: Expectation of the Gaussian distribution.
        target: Sample from the Gaussian distribution.
        var: Tensor of positive variance(s), one for each of the expectations
            in the input (heteroscedastic), or a single one (homoscedastic),
            or a positive Python scalar (``float`` or ``int``) to be used for
            all expectations.
        full (bool, optional): Whether to include the constant term in the loss calculation. Default: ``False``.
        eps (float, optional): Lower bound that ``var`` is clamped to, for stability. Default: 1e-6.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the output is the mean of the element-wise losses,
            ``'sum'``: the output is the sum of the element-wise losses.
            Default: ``'mean'``.

    Returns:
        Tensor: the element-wise loss for ``'none'``, otherwise its mean or sum.

    Raises:
        ValueError: if ``var`` contains a negative entry, if the size of ``var``
            is incompatible with ``input``, or if ``reduction`` is not one of
            ``'none'``, ``'mean'``, ``'sum'``.
    """
    if has_torch_function_variadic(input, target, var):
        return handle_torch_function(
            gaussian_nll_loss,
            (input, target, var),
            input,
            target,
            var,
            full=full,
            eps=eps,
            reduction=reduction,
        )

    # Entries of var must be non-negative.
    # Python ints are accepted alongside floats: previously an int fell through
    # to the tensor branch and crashed inside ``torch.any``.
    if isinstance(var, (int, float)):
        if var < 0:
            raise ValueError("var has negative entry/entries")
        var = var * torch.ones_like(input)
    elif torch.any(var < 0):
        raise ValueError("var has negative entry/entries")

    # Check var size.
    # If var.size == input.size, the case is heteroscedastic and no further checks are needed.
    # Otherwise:
    if var.size() != input.size():
        if input.size()[:-1] == var.size():
            # var is one dimension short of input but matches otherwise: homoscedastic case.
            # e.g. input.size = (10, 2, 3), var.size = (10, 2)
            # -> unsqueeze var so that var.shape = (10, 2, 1) and broadcasting
            #    can happen in the loss calculation.
            var = torch.unsqueeze(var, -1)
        elif input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1:
            # Sizes match up to the final dimension, and the final dimension of
            # var has size 1: also a homoscedastic case that broadcasts as is.
            # e.g. input.size = (10, 2, 3), var.size = (10, 2, 1)
            pass
        else:
            # If none of the above pass, then the size of var is incorrect.
            raise ValueError("var is of incorrect size")

    # Check validity of reduction mode.
    if reduction != "none" and reduction != "mean" and reduction != "sum":
        raise ValueError(reduction + " is not valid")

    # Clamp for stability. The clamp runs on a clone and under no_grad so the
    # caller's tensor is untouched and the clamp itself is not recorded by autograd.
    var = var.clone()
    with torch.no_grad():
        var.clamp_(min=eps)

    # Calculate the loss.
    loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var)
    if full:
        loss += 0.5 * math.log(2 * math.pi)

    if reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    else:
        return loss
def kl_div(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    log_target: bool = False,
) -> Tensor:
    r"""Compute the KL Divergence loss.

    Refer - The `Kullback-Leibler divergence Loss
    <https://en.wikipedia.org/wiki/Kullback-Leibler_divergence>`__

    See :class:`~torch.nn.KLDivLoss` for details.

    Args:
        input: Tensor of arbitrary shape in log-probabilities.
        target: Tensor of the same shape as input. See :attr:`log_target` for
            the target's interpretation.
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
            ``'none'``: no reduction will be applied
            ``'batchmean'``: the sum of the output will be divided by the batchsize
            ``'sum'``: the output will be summed
            ``'mean'``: the output will be divided by the number of elements in the output
            Default: ``'mean'``
        log_target (bool): A flag indicating whether ``target`` is passed in the log space.
            It is recommended to pass certain distributions (like ``softmax``)
            in the log space to avoid numerical issues caused by explicit ``log``.
            Default: ``False``

    Returns:
        Tensor: the reduced (or, for ``'none'``, element-wise) KL divergence loss.

    .. note::
        :attr:`size_average` and :attr:`reduce` are in the process of being deprecated,
        and in the meantime, specifying either of those two args will override :attr:`reduction`
        (including ``'batchmean'``).

    .. warning::
        :attr:`reduction` = ``'mean'`` doesn't return the true kl divergence value, please use
        :attr:`reduction` = ``'batchmean'`` which aligns with KL math definition.
    """
    if has_torch_function_variadic(input, target):
        return handle_torch_function(
            kl_div,
            (input, target),
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            log_target=log_target,
        )

    # 'batchmean' is only honoured when the legacy flags are absent; when they
    # are given they override ``reduction`` entirely, so no extra division by
    # the batch size may be applied afterwards.
    use_batchmean = False
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        if reduction == "mean":
            warnings.warn(
                "reduction: 'mean' divides the total loss by both the batch size and the support size. "
                "'batchmean' divides only by the batch size, and aligns with the KL div math definition. "
                "'mean' will be changed to behave the same as 'batchmean' in the next major release.",
                stacklevel=2,
            )
        use_batchmean = reduction == "batchmean"
        # Special case for batchmean: sum in the backend, divide by the batch size below.
        reduction_enum = _Reduction.get_enum("sum" if use_batchmean else reduction)

    reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target)

    # A 0-dim input has no batch dimension to divide by.
    if use_batchmean and input.dim() != 0:
        reduced = reduced / input.size()[0]

    return reduced
def cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    ignore_index: int = -100,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    label_smoothing: float = 0.0,
) -> Tensor:
    r"""Compute the cross entropy loss between input logits and target.

    See :class:`~torch.nn.CrossEntropyLoss` for details.

    Args:
        input (Tensor) : Predicted unnormalized logits;
            see Shape section below for supported shapes.
        target (Tensor) : Ground truth class indices or class probabilities;
            see Shape section below for supported shapes.
        weight (Tensor, optional): a manual rescaling weight given to each
            class. If given, has to be a Tensor of size `C`
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        ignore_index (int, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. When :attr:`size_average` is
            ``True``, the loss is averaged over non-ignored targets. Note that
            :attr:`ignore_index` is only applicable when the target contains class indices.
            Default: -100
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
        label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
            of smoothing when computing the loss, where 0.0 means no smoothing. The targets
            become a mixture of the original ground truth and a uniform distribution as described in
            `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`.

    Shape:
        - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
          in the case of `K`-dimensional loss.
        - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with
          :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`.
          If containing class probabilities, same shape as the input and each value should be between :math:`[0, 1]`.

        where:

        .. math::
            \begin{aligned}
                C ={} & \text{number of classes} \\
                N ={} & \text{batch size} \\
            \end{aligned}

    Examples::

        >>> # Example of target with class indices
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randint(5, (3,), dtype=torch.int64)
        >>> loss = F.cross_entropy(input, target)
        >>> loss.backward()
        >>>
        >>> # Example of target with class probabilities
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.randn(3, 5).softmax(dim=1)
        >>> loss = F.cross_entropy(input, target)
        >>> loss.backward()
    """
    tensor_args = (input, target, weight)
    if has_torch_function_variadic(*tensor_args):
        return handle_torch_function(
            cross_entropy,
            tensor_args,
            input,
            target,
            weight=weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction=reduction,
            label_smoothing=label_smoothing,
        )

    # The deprecated flags, when supplied, replace the ``reduction`` string.
    legacy_flags_given = not (size_average is None and reduce is None)
    if legacy_flags_given:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    reduction_code = _Reduction.get_enum(reduction)
    return torch._C._nn.cross_entropy_loss(
        input, target, weight, reduction_code, ignore_index, label_smoothing
    )
def binary_cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:
    r"""Compute Binary Cross Entropy between the target and input probabilities.

    See :class:`~torch.nn.BCELoss` for details.

    Args:
        input: Tensor of arbitrary shape as probabilities.
        target: Tensor of the same shape as input with values between 0 and 1.
        weight (Tensor, optional): a manual rescaling weight
            if provided it's repeated to match input tensor shape
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``

    Raises:
        ValueError: if ``target`` and ``input`` do not have the same size.

    Examples::

        >>> input = torch.randn(3, 2, requires_grad=True)
        >>> target = torch.rand(3, 2, requires_grad=False)
        >>> loss = F.binary_cross_entropy(torch.sigmoid(input), target)
        >>> loss.backward()
    """
    tensor_args = (input, target, weight)
    if has_torch_function_variadic(*tensor_args):
        return handle_torch_function(
            binary_cross_entropy,
            tensor_args,
            input,
            target,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )

    # Deprecated flags take precedence over the ``reduction`` string.
    legacy_flags_given = not (size_average is None and reduce is None)
    reduction_enum = (
        _Reduction.legacy_get_enum(size_average, reduce)
        if legacy_flags_given
        else _Reduction.get_enum(reduction)
    )

    target_shape = target.size()
    if target_shape != input.size():
        raise ValueError(
            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}) is deprecated. "
            "Please ensure they have the same size."
        )

    if weight is not None:
        # Expand the weight to the shape inferred from target and weight sizes.
        weight = weight.expand(_infer_size(target_shape, weight.size()))

    return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
def binary_cross_entropy_with_logits(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    pos_weight: Optional[Tensor] = None,
) -> Tensor:
    r"""Compute Binary Cross Entropy between target and input logits.

    See :class:`~torch.nn.BCEWithLogitsLoss` for details.

    Args:
        input: Tensor of arbitrary shape as unnormalized scores (often referred to as logits).
        target: Tensor of the same shape as input with values between 0 and 1
        weight (Tensor, optional): a manual rescaling weight
            if provided it's repeated to match input tensor shape
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
        pos_weight (Tensor, optional): a weight of positive examples to be broadcasted with target.
            Must be a tensor with equal size along the class dimension to the number of classes.
            Pay close attention to PyTorch's broadcasting semantics in order to achieve the desired
            operations. For a target of size [B, C, H, W] (where B is batch size) pos_weight of
            size [B, C, H, W] will apply different pos_weights to each element of the batch or
            [C, H, W] the same pos_weights across the batch. To apply the same positive weight
            along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1].
            Default: ``None``

    Raises:
        ValueError: if ``target`` and ``input`` do not have the same size.

    Examples::

         >>> input = torch.randn(3, requires_grad=True)
         >>> target = torch.empty(3).random_(2)
         >>> loss = F.binary_cross_entropy_with_logits(input, target)
         >>> loss.backward()
    """
    tensor_args = (input, target, weight, pos_weight)
    if has_torch_function_variadic(*tensor_args):
        return handle_torch_function(
            binary_cross_entropy_with_logits,
            tensor_args,
            input,
            target,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            pos_weight=pos_weight,
        )

    # Deprecated flags take precedence over the ``reduction`` string.
    legacy_flags_given = not (size_average is None and reduce is None)
    reduction_enum = (
        _Reduction.legacy_get_enum(size_average, reduce)
        if legacy_flags_given
        else _Reduction.get_enum(reduction)
    )

    if target.size() != input.size():
        raise ValueError(
            f"Target size ({target.size()}) must be the same as input size ({input.size()})"
        )

    return torch.binary_cross_entropy_with_logits(
        input, target, weight, pos_weight, reduction_enum
    )
def smooth_l1_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    beta: float = 1.0,
) -> Tensor:
    r"""Compute the Smooth L1 loss.

    Function uses a squared term if the absolute
    element-wise error falls below beta and an L1 term otherwise.
    When ``beta`` is exactly ``0.0`` the plain L1 loss kernel is used instead.

    See :class:`~torch.nn.SmoothL1Loss` for details.

    Args:
        input (Tensor): Predicted values.
        target (Tensor): Ground truth values. A warning is issued when its size
            differs from ``input``; the two are then broadcast against each other.
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
            'sum': the output will be summed. 'none': no reduction will be applied.
            Default: 'mean'.
        beta (float, optional): Specifies the threshold at which to change from the squared
            term to the L1 term in the loss calculation. This value must be positive.
            Default: 1.0.

    Returns:
        Tensor: Smooth L1 loss.
    """
    tensor_args = (input, target)
    if has_torch_function_variadic(*tensor_args):
        return handle_torch_function(
            smooth_l1_loss,
            tensor_args,
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            beta=beta,
        )

    if target.size() != input.size():
        warnings.warn(
            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
            "This will likely lead to incorrect results due to broadcasting. "
            "Please ensure they have the same size.",
            stacklevel=2,
        )

    # Deprecated flags, when supplied, replace the ``reduction`` string.
    legacy_flags_given = not (size_average is None and reduce is None)
    if legacy_flags_given:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    lhs, rhs = torch.broadcast_tensors(input, target)

    if beta != 0.0:
        return torch._C._nn.smooth_l1_loss(
            lhs, rhs, _Reduction.get_enum(reduction), beta
        )
    # beta == 0.0: fall back to the L1 kernel.
    return torch._C._nn.l1_loss(lhs, rhs, _Reduction.get_enum(reduction))
def huber_loss(
    input: Tensor,
    target: Tensor,
    reduction: str = "mean",
    delta: float = 1.0,
    weight: Optional[Tensor] = None,
) -> Tensor:
    r"""Compute the Huber loss, with optional weighting.

    Function uses a squared term if the absolute
    element-wise error falls below delta and a delta-scaled L1 term otherwise.

    When delta equals 1, this loss is equivalent to SmoothL1Loss.
    In general, Huber loss differs from SmoothL1Loss by a factor of delta (AKA beta in Smooth L1).

    See :class:`~torch.nn.HuberLoss` for details.

    Args:
        input (Tensor): Predicted values.
        target (Tensor): Ground truth values. A warning is issued when its size
            differs from ``input``; the two are then broadcast against each other.
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
            'sum': the output will be summed. 'none': no reduction will be applied.
            Default: 'mean'.
        delta (float, optional): The threshold at which to change between delta-scaled L1 and L2 loss. Default: 1.0.
        weight (Tensor, optional): Weights for each sample; must have the same size
            as ``input``. With 'mean' the weighted losses are averaged with a plain
            ``.mean()``. Default: None.

    Returns:
        Tensor: Huber loss (optionally weighted).

    Raises:
        ValueError: if ``weight`` is given and its size differs from ``input``,
            or (on the weighted path) if ``reduction`` is not 'none', 'mean' or 'sum'.
    """
    tensor_args = (input, target, weight)
    if has_torch_function_variadic(*tensor_args):
        return handle_torch_function(
            huber_loss,
            tensor_args,
            input,
            target,
            reduction=reduction,
            delta=delta,
            weight=weight,
        )

    if target.size() != input.size():
        warnings.warn(
            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
            "This will likely lead to incorrect results due to broadcasting. "
            "Please ensure they have the same size.",
            stacklevel=2,
        )

    lhs, rhs = torch.broadcast_tensors(input, target)

    # Unweighted: let the backend kernel perform the reduction itself.
    if weight is None:
        return torch._C._nn.huber_loss(lhs, rhs, _Reduction.get_enum(reduction), delta)

    if weight.size() != input.size():
        raise ValueError("Weights and input must have the same size.")

    # Weighted: obtain element-wise losses from the backend, scale them, then reduce here.
    elementwise = torch._C._nn.huber_loss(lhs, rhs, _Reduction.get_enum("none"), delta)
    scaled = elementwise * weight

    if reduction == "none":
        return scaled
    if reduction == "sum":
        return torch.sum(scaled)
    if reduction == "mean":
        return scaled.mean()
    raise ValueError(
        f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'."
    )
def l1_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    weight: Optional[Tensor] = None,
) -> Tensor:  # noqa: D400,D402
    r"""Compute the L1 loss, with optional weighting.

    Function that takes the mean element-wise absolute value difference.

    See :class:`~torch.nn.L1Loss` for details.

    Args:
        input (Tensor): Predicted values.
        target (Tensor): Ground truth values. A warning is issued when its size
            differs from ``input``; the two are then broadcast against each other.
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
            'sum': the output will be summed. 'none': no reduction will be applied.
            Default: 'mean'.
        weight (Tensor, optional): Weights for each sample; must have the same size
            as ``input``. With 'mean' the weighted sum of absolute errors is divided
            by the sum of the weights. Default: None.

    Returns:
        Tensor: L1 loss (optionally weighted).

    Raises:
        ValueError: if ``weight`` is given and its size differs from ``input``,
            or (on the weighted path) if ``reduction`` is not 'none', 'mean' or 'sum'.
    """
    # ``weight`` takes part in the __torch_function__ dispatch and is forwarded
    # to the override, matching mse_loss and huber_loss; otherwise an override
    # would silently compute the unweighted loss.
    if has_torch_function_variadic(input, target, weight):
        return handle_torch_function(
            l1_loss,
            (input, target, weight),
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            weight=weight,
        )

    if not (target.size() == input.size()):
        warnings.warn(
            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
            "This will likely lead to incorrect results due to broadcasting. "
            "Please ensure they have the same size.",
            stacklevel=2,
        )

    # Deprecated flags, when supplied, replace the ``reduction`` string.
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    expanded_input, expanded_target = torch.broadcast_tensors(input, target)

    if weight is None:
        # Unweighted: let the backend kernel perform the reduction itself.
        return torch._C._nn.l1_loss(
            expanded_input, expanded_target, _Reduction.get_enum(reduction)
        )

    if weight.size() != input.size():
        raise ValueError("Weights and input must have the same size.")

    weighted_absolute_errors = torch.abs(expanded_input - expanded_target) * weight

    if reduction == "none":
        return weighted_absolute_errors
    elif reduction == "sum":
        return torch.sum(weighted_absolute_errors)
    elif reduction == "mean":
        # Weighted mean: normalise by the total weight, not the element count.
        return torch.sum(weighted_absolute_errors) / torch.sum(weight)
    else:
        raise ValueError(
            f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'."
        )
def mse_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    weight: Optional[Tensor] = None,
) -> Tensor:
    r"""Compute the element-wise mean squared error, with optional weighting.

    See :class:`~torch.nn.MSELoss` for details.

    Args:
        input (Tensor): Predicted values.
        target (Tensor): Ground truth values. A warning is issued when its size
            differs from ``input``; the two are then broadcast against each other.
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
            'sum': the output will be summed. 'none': no reduction will be applied.
            Default: 'mean'.
        weight (Tensor, optional): Weights for each sample; must have the same size
            as ``input``. With 'mean' the weighted sum of squared errors is divided
            by the sum of the weights. Default: None.

    Returns:
        Tensor: Mean Squared Error loss (optionally weighted).

    Raises:
        ValueError: if ``weight`` is given and its size differs from ``input``,
            or (on the weighted path) if ``reduction`` is not 'none', 'mean' or 'sum'.
    """
    tensor_args = (input, target, weight)
    if has_torch_function_variadic(*tensor_args):
        return handle_torch_function(
            mse_loss,
            tensor_args,
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            weight=weight,
        )

    if target.size() != input.size():
        warnings.warn(
            f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
            "This will likely lead to incorrect results due to broadcasting. "
            "Please ensure they have the same size.",
            stacklevel=2,
        )

    # Deprecated flags, when supplied, replace the ``reduction`` string.
    legacy_flags_given = not (size_average is None and reduce is None)
    if legacy_flags_given:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    lhs, rhs = torch.broadcast_tensors(input, target)

    # Unweighted: let the backend kernel perform the reduction itself.
    if weight is None:
        return torch._C._nn.mse_loss(lhs, rhs, _Reduction.get_enum(reduction))

    if weight.size() != input.size():
        raise ValueError("Weights and input must have the same size.")

    # Weighted: compute and reduce the squared errors here.
    scaled_sq_err = torch.pow(lhs - rhs, 2) * weight

    if reduction == "none":
        return scaled_sq_err
    if reduction == "sum":
        return torch.sum(scaled_sq_err)
    if reduction == "mean":
        return torch.sum(scaled_sq_err) / torch.sum(weight)
    raise ValueError(
        f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'."
    )
def margin_ranking_loss(
    input1: Tensor,
    input2: Tensor,
    target: Tensor,
    margin: float = 0,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""Compute the margin ranking loss.

    See :class:`~torch.nn.MarginRankingLoss` for details.

    Args:
        input1 (Tensor): Predicted values.
        input2 (Tensor): Predicted values.
        target (Tensor): Ground truth values.
        margin (float, optional): Margin passed through to the backend loss. Default: 0.
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
            'sum': the output will be summed. 'none': no reduction will be applied.
            Default: 'mean'.

    Returns:
        Tensor: Margin ranking loss.

    Raises:
        RuntimeError: if the three tensors do not all have the same number of dimensions.
    """
    tensor_args = (input1, input2, target)
    if has_torch_function_variadic(*tensor_args):
        return handle_torch_function(
            margin_ranking_loss,
            tensor_args,
            input1,
            input2,
            target,
            margin=margin,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )

    # Deprecated flags take precedence over the ``reduction`` string.
    legacy_flags_given = not (size_average is None and reduce is None)
    reduction_enum = (
        _Reduction.legacy_get_enum(size_average, reduce)
        if legacy_flags_given
        else _Reduction.get_enum(reduction)
    )

    same_rank = input1.dim() == input2.dim() == target.dim()
    if not same_rank:
        raise RuntimeError(
            f"margin_ranking_loss : All input tensors should have same dimension but got sizes: "
            f"input1: {input1.size()}, input2: {input2.size()}, target: {target.size()} "
        )

    return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum)
def hinge_embedding_loss(
    input: Tensor,
    target: Tensor,
    margin: float = 1.0,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""Compute the hinge embedding loss.

    See :class:`~torch.nn.HingeEmbeddingLoss` for details.

    Args:
        input (Tensor): Predicted values.
        target (Tensor): Ground truth values.
        margin (float, optional): Margin for hinge loss. Has a default value of 1.
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
            'sum': the output will be summed. 'none': no reduction will be applied.
            Default: 'mean'.

    Returns:
        Tensor: Hinge embedding loss.
    """
    tensor_args = (input, target)
    if has_torch_function_variadic(*tensor_args):
        return handle_torch_function(
            hinge_embedding_loss,
            tensor_args,
            input,
            target,
            margin=margin,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )

    # Deprecated flags take precedence over the ``reduction`` string.
    legacy_flags_given = not (size_average is None and reduce is None)
    reduction_enum = (
        _Reduction.legacy_get_enum(size_average, reduce)
        if legacy_flags_given
        else _Reduction.get_enum(reduction)
    )
    return torch.hinge_embedding_loss(input, target, margin, reduction_enum)
def multilabel_margin_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""Compute the multilabel margin loss.

    See :class:`~torch.nn.MultiLabelMarginLoss` for details.

    Args:
        input (Tensor): Predicted values.
        target (Tensor): Ground truth values.
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
            'sum': the output will be summed. 'none': no reduction will be applied.
            Default: 'mean'.

    Returns:
        Tensor: Multilabel margin loss.
    """
    tensor_args = (input, target)
    if has_torch_function_variadic(*tensor_args):
        return handle_torch_function(
            multilabel_margin_loss,
            tensor_args,
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )

    # Deprecated flags take precedence over the ``reduction`` string.
    legacy_flags_given = not (size_average is None and reduce is None)
    reduction_enum = (
        _Reduction.legacy_get_enum(size_average, reduce)
        if legacy_flags_given
        else _Reduction.get_enum(reduction)
    )
    return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)
def soft_margin_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""Compute the soft margin loss.

    See :class:`~torch.nn.SoftMarginLoss` for details.

    Args:
        input (Tensor): Predicted values.
        target (Tensor): Ground truth values.
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
            'sum': the output will be summed. 'none': no reduction will be applied.
            Default: 'mean'.

    Returns:
        Tensor: Soft margin loss.
    """
    tensor_args = (input, target)
    if has_torch_function_variadic(*tensor_args):
        return handle_torch_function(
            soft_margin_loss,
            tensor_args,
            input,
            target,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )

    # Deprecated flags take precedence over the ``reduction`` string.
    legacy_flags_given = not (size_average is None and reduce is None)
    reduction_enum = (
        _Reduction.legacy_get_enum(size_average, reduce)
        if legacy_flags_given
        else _Reduction.get_enum(reduction)
    )
    return torch._C._nn.soft_margin_loss(input, target, reduction_enum)
def multilabel_soft_margin_loss(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""Compute the multilabel soft margin loss.

    See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details.

    Args:
        input (Tensor): Predicted values.
        target (Tensor): Ground truth values.
        weight (Tensor, optional): If given, multiplied into the element-wise loss
            before it is averaged over the last dimension of ``input``. Default: None.
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
            'sum': the output will be summed. 'none': no reduction will be applied.
            Default: 'mean'.

    Returns:
        Tensor: Multilabel soft margin loss.

    Raises:
        ValueError: if ``reduction`` is not 'none', 'mean' or 'sum'.
    """
    tensor_args = (input, target, weight)
    if has_torch_function_variadic(*tensor_args):
        return handle_torch_function(
            multilabel_soft_margin_loss,
            tensor_args,
            input,
            target,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )

    # Deprecated flags, when supplied, replace the ``reduction`` string.
    legacy_flags_given = not (size_average is None and reduce is None)
    if legacy_flags_given:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    positive_term = target * logsigmoid(input)
    negative_term = (1 - target) * logsigmoid(-input)
    per_class = -(positive_term + negative_term)
    if weight is not None:
        per_class = per_class * weight

    # The class dimension is the last dimension of ``input``; averaging over it
    # leaves one loss value per remaining index.
    class_dim = input.dim() - 1
    num_classes = input.size(class_dim)
    per_sample = per_class.sum(dim=class_dim) / num_classes

    if reduction == "none":
        return per_sample
    if reduction == "mean":
        return per_sample.mean()
    if reduction == "sum":
        return per_sample.sum()
    raise ValueError(reduction + " is not valid")
def cosine_embedding_loss(
    input1: Tensor,
    input2: Tensor,
    target: Tensor,
    margin: float = 0,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""Compute the cosine embedding loss.

    See :class:`~torch.nn.CosineEmbeddingLoss` for details.

    Args:
        input1 (Tensor): Predicted values.
        input2 (Tensor): Predicted values.
        target (Tensor): Ground truth values.
        margin (float, optional): Margin for cosine embedding. Has a default value of 0.
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
            'sum': the output will be summed. 'none': no reduction will be applied.
            Default: 'mean'.

    Returns:
        Tensor: Cosine embedding loss.
    """
    tensor_args = (input1, input2, target)
    if has_torch_function_variadic(*tensor_args):
        return handle_torch_function(
            cosine_embedding_loss,
            tensor_args,
            input1,
            input2,
            target,
            margin=margin,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )

    # Deprecated flags take precedence over the ``reduction`` string.
    legacy_flags_given = not (size_average is None and reduce is None)
    reduction_enum = (
        _Reduction.legacy_get_enum(size_average, reduce)
        if legacy_flags_given
        else _Reduction.get_enum(reduction)
    )
    return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum)
def multi_margin_loss(
    input: Tensor,
    target: Tensor,
    p: int = 1,
    margin: float = 1.0,
    weight: Optional[Tensor] = None,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:  # noqa: D400,D402
    r"""Compute the multi margin loss, with optional weighting.

    See :class:`~torch.nn.MultiMarginLoss` for details.

    Args:
        input (Tensor): Predicted values.
        target (Tensor): Ground truth values.
        p (int, optional): Has a default value of 1. 1 and 2 are the only supported values.
        margin (float, optional): Margin for multi margin loss. Has a default value of 1.
        weight (Tensor, optional): Rescaling weights; must be one-dimensional when given.
            Default: None.
        size_average (bool, optional): Deprecated (see :attr:`reduction`).
        reduce (bool, optional): Deprecated (see :attr:`reduction`).
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
            'sum': the output will be summed. 'none': no reduction will be applied.
            Default: 'mean'.

    Returns:
        Tensor: Multi margin loss (optionally weighted).

    Raises:
        ValueError: if ``p`` is neither 1 nor 2, or if ``weight`` is given and is
            not one-dimensional.
    """
    tensor_args = (input, target, weight)
    if has_torch_function_variadic(*tensor_args):
        return handle_torch_function(
            multi_margin_loss,
            tensor_args,
            input,
            target,
            p=p,
            margin=margin,
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )

    # Deprecated flags take precedence over the ``reduction`` string.
    legacy_flags_given = not (size_average is None and reduce is None)
    reduction_enum = (
        _Reduction.legacy_get_enum(size_average, reduce)
        if legacy_flags_given
        else _Reduction.get_enum(reduction)
    )

    if not (p == 1 or p == 2):
        raise ValueError("only p == 1 and p == 2 supported")
    if weight is not None and weight.dim() != 1:
        raise ValueError("weight must be one-dimensional")

    return torch._C._nn.multi_margin_loss(
        input, target, p, margin, weight, reduction_enum
    )
# Bind ``pixel_shuffle`` in this module to the result of attaching the
# documentation below to the native ``torch.pixel_shuffle`` op.
# NOTE(review): presumably ``_add_docstr`` returns the same callable it was
# given -- confirm against its definition.
pixel_shuffle = _add_docstr(
    torch.pixel_shuffle,
    r"""
pixel_shuffle(input, upscale_factor) -> Tensor

Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a
tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is the :attr:`upscale_factor`.

See :class:`~torch.nn.PixelShuffle` for details.

Args:
    input (Tensor): the input tensor
    upscale_factor (int): factor to increase spatial resolution by

Examples::

    >>> input = torch.randn(1, 9, 4, 4)
    >>> output = torch.nn.functional.pixel_shuffle(input, 3)
    >>> print(output.size())
    torch.Size([1, 1, 12, 12])
""",
)
# Bind ``pixel_unshuffle`` in this module to the result of attaching the
# documentation below to the native ``torch.pixel_unshuffle`` op.
# NOTE(review): presumably ``_add_docstr`` returns the same callable it was
# given -- confirm against its definition.
pixel_unshuffle = _add_docstr(
    torch.pixel_unshuffle,
    r"""
pixel_unshuffle(input, downscale_factor) -> Tensor

Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements in a
tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape
:math:`(*, C \times r^2, H, W)`, where r is the :attr:`downscale_factor`.

See :class:`~torch.nn.PixelUnshuffle` for details.

Args:
    input (Tensor): the input tensor
    downscale_factor (int): factor to decrease spatial resolution by

Examples::

    >>> input = torch.randn(1, 1, 12, 12)
    >>> output = torch.nn.functional.pixel_unshuffle(input, 3)
    >>> print(output.size())
    torch.Size([1, 9, 4, 4])
""",
)
# Bind ``channel_shuffle`` in this module to the result of attaching the
# documentation below to the native ``torch.channel_shuffle`` op.
# NOTE(review): presumably ``_add_docstr`` returns the same callable it was
# given -- confirm against its definition.
channel_shuffle = _add_docstr(
    torch.channel_shuffle,
    r"""
channel_shuffle(input, groups) -> Tensor

Divide the channels in a tensor of shape :math:`(*, C , H, W)`
into g groups and rearrange them as :math:`(*, \frac{C}{g}, g, H, W)`,
while keeping the original tensor shape.

See :class:`~torch.nn.ChannelShuffle` for details.

Args:
    input (Tensor): the input tensor
    groups (int): number of groups to divide channels in and rearrange.

Examples::

    >>> input = torch.randn(1, 4, 2, 2)
    >>> print(input)
    [[[[1, 2],
       [3, 4]],
      [[5, 6],
       [7, 8]],
      [[9, 10],
       [11, 12]],
      [[13, 14],
       [15, 16]],
     ]]
    >>> output = torch.nn.functional.channel_shuffle(input, 2)
    >>> print(output)
    [[[[1, 2],
       [3, 4]],
      [[9, 10],
       [11, 12]],
      [[5, 6],
       [7, 8]],
      [[13, 14],
       [15, 16]],
     ]]
""",
)
# Bind ``native_channel_shuffle`` in this module to the result of attaching
# the documentation below to the native ``torch.native_channel_shuffle`` op.
# NOTE(review): presumably ``_add_docstr`` returns the same callable it was
# given -- confirm against its definition.
native_channel_shuffle = _add_docstr(
    torch.native_channel_shuffle,
    r"""
native_channel_shuffle(input, groups) -> Tensor

Native kernel level implementation of the `channel_shuffle`.
This function might become private in future releases, use with caution.

Divide the channels in a tensor of shape :math:`(*, C , H, W)`
into g groups and rearrange them as :math:`(*, \frac{C}{g}, g, H, W)`,
while keeping the original tensor shape.

See :class:`~torch.nn.ChannelShuffle` for details.

Args:
    input (Tensor): the input tensor
    groups (int): number of groups to divide channels in and rearrange.

Examples::

    >>> input = torch.randn(1, 4, 2, 2)
    >>> print(input)
    [[[[1, 2],
       [3, 4]],
      [[5, 6],
       [7, 8]],
      [[9, 10],
       [11, 12]],
      [[13, 14],
       [15, 16]],
     ]]
    >>> output = torch.nn.functional.native_channel_shuffle(input, 2)
    >>> print(output)
    [[[[1, 2],
       [3, 4]],
      [[9, 10],
       [11, 12]],
      [[5, 6],
       [7, 8]],
      [[13, 14],
       [15, 16]],
     ]]
""",
)
# Typed overload: a single integer ``size`` applied to every spatial dimension.
@_overload
def upsample(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
) -> Tensor:  # noqa: B950
    pass


# Typed overload: one ``size`` entry per spatial dimension.
@_overload
def upsample(  # noqa: F811
    input: Tensor,
    size: Optional[list[int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
) -> Tensor:  # noqa: B950
    pass


def upsample(  # noqa: F811
    input,
    size=None,
    scale_factor=None,
    mode="nearest",
    align_corners=None,
):
    r"""Upsample the input (deprecated alias of :func:`interpolate`).

    Provided tensor is upsampled to either the given :attr:`size` or the given
    :attr:`scale_factor`

    .. warning::
        This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
        This is equivalent with ``nn.functional.interpolate(...)``.

    Note:
        {backward_reproducibility_note}

    The algorithm used for upsampling is determined by :attr:`mode`.

    Currently temporal, spatial and volumetric upsampling are supported, i.e.
    expected inputs are 3-D, 4-D or 5-D in shape.

    The input dimensions are interpreted in the form:
    `mini-batch x channels x [optional depth] x [optional height] x width`.

    The modes available for upsampling are: `nearest`, `linear` (3D-only),
    `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only)

    Args:
        input (Tensor): the input tensor
        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
            output spatial size.
        scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple.
        mode (str): algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``
        align_corners (bool, optional): Geometrically, we consider the pixels of the
            input and output as squares rather than points.
            If set to ``True``, the input and output tensors are aligned by the
            center points of their corner pixels, preserving the values at the corner pixels.
            If set to ``False``, the input and output tensors are aligned by the corner
            points of their corner pixels, and the interpolation uses edge value padding
            for out-of-boundary values, making this operation *independent* of input size
            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
            is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.
            Default: ``False``

    .. note::
        With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce
        negative values or values greater than 255 for images.
        Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot
        when displaying the image.

    .. warning::
        With ``align_corners = True``, the linearly interpolating modes
        (`linear`, `bilinear`, and `trilinear`) don't proportionally align the
        output and input pixels, and thus the output values can depend on the
        input size. This was the default behavior for these modes up to version
        0.3.1. Since then, the default behavior is ``align_corners = False``.
        See :class:`~torch.nn.Upsample` for concrete examples on how this
        affects the outputs.

    """
    deprecation_msg = (
        "`nn.functional.upsample` is deprecated. "
        "Use `nn.functional.interpolate` instead."
    )
    # stacklevel=2 is passed through so the warning is reported one frame up.
    warnings.warn(deprecation_msg, stacklevel=2)
    return interpolate(input, size, scale_factor, mode, align_corners)


# Substitute the shared reproducibility note into the docstring placeholder;
# skipped when the docstring is empty or missing.
if upsample.__doc__:
    upsample.__doc__ = upsample.__doc__.format(**reproducibility_notes)
- def _is_integer(x) -> bool:
- r"""Type check the input number is an integer.
- Will return True for int, SymInt, Numpy integers and Tensors with integer elements.
- """
- if isinstance(x, (int, torch.SymInt)):
- return True
- if np is not None and isinstance(x, np.integer):
- return True
- return isinstance(x, Tensor) and not x.is_floating_point()
# Typed overload stubs for ``interpolate``: every combination of scalar/list
# ``size`` with scalar/list ``scale_factor``. The bodies are intentionally
# empty; the untagged definition further down is the implementation.
@_overload
def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[list[float]] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:  # noqa: B950
    pass


@_overload
def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[list[int]] = None,
    scale_factor: Optional[list[float]] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:  # noqa: B950
    pass


@_overload
def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:  # noqa: B950
    pass


@_overload
def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[list[int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:
    pass


def interpolate(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[list[float]] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> Tensor:  # noqa: B950
    r"""Down/up samples the input.

    Tensor interpolated to either the given :attr:`size` or the given
    :attr:`scale_factor`

    The algorithm used for interpolation is determined by :attr:`mode`.

    Currently temporal, spatial and volumetric sampling are supported, i.e.
    expected inputs are 3-D, 4-D or 5-D in shape.

    The input dimensions are interpreted in the form:
    `mini-batch x channels x [optional depth] x [optional height] x width`.

    The modes available for resizing are: `nearest`, `linear` (3D-only),
    `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`, `nearest-exact`

    Args:
        input (Tensor): the input tensor
        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
            output spatial size.
        scale_factor (float or Tuple[float]): multiplier for spatial size. If `scale_factor` is a tuple,
            its length has to match the number of spatial dimensions; `input.dim() - 2`.
        mode (str): algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'`` | ``'area'`` | ``'nearest-exact'``. Default: ``'nearest'``
        align_corners (bool, optional): Geometrically, we consider the pixels of the
            input and output as squares rather than points.
            If set to ``True``, the input and output tensors are aligned by the
            center points of their corner pixels, preserving the values at the corner pixels.
            If set to ``False``, the input and output tensors are aligned by the corner
            points of their corner pixels, and the interpolation uses edge value padding
            for out-of-boundary values, making this operation *independent* of input size
            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
            is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.
            Default: ``False``
        recompute_scale_factor (bool, optional): recompute the scale_factor for use in the
            interpolation calculation. If `recompute_scale_factor` is ``True``, then
            `scale_factor` must be passed in and `scale_factor` is used to compute the
            output `size`. The computed output `size` will be used to infer new scales for
            the interpolation. Note that when `scale_factor` is floating-point, it may differ
            from the recomputed `scale_factor` due to rounding and precision issues.
            If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will
            be used directly for interpolation. Default: ``None``.
        antialias (bool, optional): flag to apply anti-aliasing. Default: ``False``. Using anti-alias
            option together with ``align_corners=False``, interpolation result would match Pillow
            result for downsampling operation. Supported modes: ``'bilinear'``, ``'bicubic'``.

    Returns:
        Tensor: the resampled tensor.

    Raises:
        ValueError: if :attr:`align_corners` is set for a non-interpolating mode, if both or
            neither of :attr:`size` and :attr:`scale_factor` are given, if their length does
            not match the number of spatial dimensions, if :attr:`recompute_scale_factor` is
            combined with an explicit :attr:`size`, or if :attr:`antialias` is requested
            outside 4-D bilinear/bicubic.
        TypeError: if :attr:`size` is a sequence containing non-integer entries (not checked
            under TorchScript scripting).
        NotImplementedError: if the combination of input rank and :attr:`mode` is unsupported.

    .. note::
        With ``mode='bicubic'``, it's possible to cause overshoot. For some dtypes, it can produce
        negative values or values greater than 255 for images. Explicitly call ``result.clamp(min=0,max=255)``
        if you want to reduce the overshoot when displaying the image.
        For ``uint8`` inputs, it already performs saturating cast operation. So, no manual `clamp` operation is needed.

    .. note::
        Mode ``mode='nearest-exact'`` matches Scikit-Image and PIL nearest neighbours interpolation
        algorithms and fixes known issues with ``mode='nearest'``. This mode is introduced to keep
        backward compatibility.
        Mode ``mode='nearest'`` matches buggy OpenCV's ``INTER_NEAREST`` interpolation algorithm.

    .. note::
        The gradients for the dtype ``float16`` on CUDA may be inaccurate in the upsample operation
        when using modes ``['linear', 'bilinear', 'bicubic', 'trilinear', 'area']``.
        For more details, please refer to the discussion in
        `issue#104157 <https://github.com/pytorch/pytorch/issues/104157>`_.

    Note:
        {backward_reproducibility_note}
    """
    # Hand the call to handle_torch_function when has_torch_function_unary
    # reports that ``input`` needs it.
    if has_torch_function_unary(input):
        return handle_torch_function(
            interpolate,
            (input,),
            input,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
            recompute_scale_factor=recompute_scale_factor,
            antialias=antialias,
        )

    # align_corners is rejected for the non-interpolating modes and defaults
    # to False for every other mode.
    if mode in ("nearest", "area", "nearest-exact"):
        if align_corners is not None:
            raise ValueError(
                "align_corners option can only be set with the "
                "interpolating modes: linear | bilinear | bicubic | trilinear"
            )
    else:
        if align_corners is None:
            align_corners = False

    dim = input.dim() - 2  # Number of spatial dimensions.

    # Process size and scale_factor.  Validate that exactly one is set.
    # Validate its length if it is a list, or expand it if it is a scalar.
    # After this block, exactly one of output_size and scale_factors will
    # be non-None, and it will be a list (or tuple).
    if size is not None and scale_factor is not None:
        raise ValueError("only one of size or scale_factor should be defined")
    elif size is not None:
        assert scale_factor is None
        scale_factors = None
        if isinstance(size, (list, tuple)):
            if len(size) != dim:
                raise ValueError(
                    "Input and output must have the same number of spatial dimensions, but got "
                    f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. "
                    "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
                    "output size in (o1, o2, ...,oK) format."
                )
            # The element-type check is skipped when scripting.
            if not torch.jit.is_scripting():
                if not all(_is_integer(x) for x in size):
                    raise TypeError(
                        "expected size to be one of int or Tuple[int] or Tuple[int, int] or "
                        f"Tuple[int, int, int], but got size with types {[type(x) for x in size]}"
                    )
            output_size = size
        else:
            # Scalar size: repeat it for every spatial dimension.
            output_size = [size for _ in range(dim)]
    elif scale_factor is not None:
        assert size is None
        output_size = None
        if isinstance(scale_factor, (list, tuple)):
            if len(scale_factor) != dim:
                raise ValueError(
                    "Input and scale_factor must have the same number of spatial dimensions, but "
                    f"got input with spatial dimensions of {list(input.shape[2:])} and "
                    f"scale_factor of shape {scale_factor}. "
                    "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
                    "scale_factor in (s1, s2, ...,sK) format."
                )
            scale_factors = scale_factor
        else:
            # Scalar scale factor: repeat it for every spatial dimension.
            scale_factors = [scale_factor for _ in range(dim)]
    else:
        raise ValueError("either size or scale_factor should be defined")

    if (
        recompute_scale_factor is not None
        and recompute_scale_factor
        and size is not None
    ):
        raise ValueError(
            "recompute_scale_factor is not meaningful with an explicit size."
        )

    # "area" mode always requires an explicit size rather than scale factor.
    # Reuse the recompute_scale_factor code path.
    if mode == "area" and output_size is None:
        recompute_scale_factor = True

    if recompute_scale_factor is not None and recompute_scale_factor:
        # We compute output_size here, then un-set scale_factors.
        # The C++ code will recompute it based on the (integer) output size.
        assert scale_factors is not None
        if not torch.jit.is_scripting() and torch._C._get_tracing_state():
            # make scale_factor a tensor in tracing so constant doesn't get baked in
            output_size = [
                (
                    torch.floor(
                        (
                            input.size(i + 2).float()
                            * torch.tensor(scale_factors[i], dtype=torch.float32)
                        ).float()
                    )
                )
                for i in range(dim)
            ]
        elif torch.jit.is_scripting():
            output_size = [
                int(math.floor(float(input.size(i + 2)) * scale_factors[i]))
                for i in range(dim)
            ]
        else:
            output_size = [
                _sym_int(input.size(i + 2) * scale_factors[i]) for i in range(dim)
            ]
        scale_factors = None

    if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4):
        raise ValueError(
            "Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input"
        )

    # Dispatch on (input rank, mode) to the matching kernel.
    if input.dim() == 3 and mode == "nearest":
        return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
    if input.dim() == 4 and mode == "nearest":
        return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
    if input.dim() == 5 and mode == "nearest":
        return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)

    if input.dim() == 3 and mode == "nearest-exact":
        return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
    if input.dim() == 4 and mode == "nearest-exact":
        return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
    if input.dim() == 5 and mode == "nearest-exact":
        return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)

    # "area" is implemented with adaptive average pooling, which is why an
    # explicit output_size was forced above.
    if input.dim() == 3 and mode == "area":
        assert output_size is not None
        return adaptive_avg_pool1d(input, output_size)
    if input.dim() == 4 and mode == "area":
        assert output_size is not None
        return adaptive_avg_pool2d(input, output_size)
    if input.dim() == 5 and mode == "area":
        assert output_size is not None
        return adaptive_avg_pool3d(input, output_size)

    if input.dim() == 3 and mode == "linear":
        assert align_corners is not None
        return torch._C._nn.upsample_linear1d(
            input, output_size, align_corners, scale_factors
        )
    if input.dim() == 4 and mode == "bilinear":
        assert align_corners is not None
        if antialias:
            return torch._C._nn._upsample_bilinear2d_aa(
                input, output_size, align_corners, scale_factors
            )
        # Two levels are necessary to prevent TorchScript from touching
        # are_deterministic_algorithms_enabled.
        if not torch.jit.is_scripting():
            if not input.is_cpu and torch.are_deterministic_algorithms_enabled():
                # Use slow decomp whose backward will be in terms of index_put
                # importlib is required because the import cannot be top level
                # (cycle) and cannot be nested (TS doesn't support)
                return importlib.import_module(
                    "torch._decomp.decompositions"
                )._upsample_linear_vec(input, output_size, align_corners, scale_factors)
        return torch._C._nn.upsample_bilinear2d(
            input, output_size, align_corners, scale_factors
        )
    if input.dim() == 5 and mode == "trilinear":
        assert align_corners is not None
        # Two levels are necessary to prevent TorchScript from touching
        # are_deterministic_algorithms_enabled.
        if not torch.jit.is_scripting():
            if not input.is_cpu and torch.are_deterministic_algorithms_enabled():
                # Use slow decomp whose backward will be in terms of index_put
                # importlib is required because the import cannot be top level
                # (cycle) and cannot be nested (TS doesn't support)
                return importlib.import_module(
                    "torch._decomp.decompositions"
                )._upsample_linear_vec(input, output_size, align_corners, scale_factors)
        return torch._C._nn.upsample_trilinear3d(
            input, output_size, align_corners, scale_factors
        )
    if input.dim() == 4 and mode == "bicubic":
        assert align_corners is not None
        if antialias:
            return torch._C._nn._upsample_bicubic2d_aa(
                input, output_size, align_corners, scale_factors
            )
        return torch._C._nn.upsample_bicubic2d(
            input, output_size, align_corners, scale_factors
        )

    # Known rank/mode mismatches get a targeted message; everything else
    # falls through to the generic error at the end.
    if input.dim() == 3 and mode == "bilinear":
        raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")
    if input.dim() == 3 and mode == "trilinear":
        raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input")
    if input.dim() == 4 and mode == "linear":
        raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
    if input.dim() == 4 and mode == "trilinear":
        raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
    if input.dim() == 5 and mode == "linear":
        raise NotImplementedError("Got 5D input, but linear mode needs 3D input")
    if input.dim() == 5 and mode == "bilinear":
        raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")

    raise NotImplementedError(
        "Input Error: Only 3D, 4D and 5D input Tensors supported"
        f" (got {input.dim()}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact"
        f" (got {mode})"
    )


# Substitute the shared reproducibility note into the docstring placeholder;
# skipped when the docstring is empty or missing.
if interpolate.__doc__:
    interpolate.__doc__ = interpolate.__doc__.format(**reproducibility_notes)
# Typed overload: a single integer ``size`` applied to every spatial dimension.
@_overload
def upsample_nearest(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[float] = None,
) -> Tensor:
    pass


# Typed overload: one ``size`` entry per spatial dimension.
@_overload
def upsample_nearest(  # noqa: F811
    input: Tensor,
    size: Optional[list[int]] = None,
    scale_factor: Optional[float] = None,
) -> Tensor:
    pass


def upsample_nearest(input, size=None, scale_factor=None):  # noqa: F811
    r"""Upsamples the input, using nearest neighbours' pixel values.

    .. warning::
        This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
        This is equivalent with ``nn.functional.interpolate(..., mode='nearest')``.

    Currently spatial and volumetric upsampling are supported (i.e. expected
    inputs are 4 or 5 dimensional).

    Args:
        input (Tensor): input
        size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial
            size.
        scale_factor (int): multiplier for spatial size. Has to be an integer.

    Returns:
        Tensor: the result of ``interpolate(input, size, scale_factor, mode="nearest")``.

    Note:
        {backward_reproducibility_note}
    """
    # DeprecationWarning is ignored by default, so no explicit category is
    # passed here.
    warnings.warn(
        "`nn.functional.upsample_nearest` is deprecated. "
        "Use `nn.functional.interpolate` instead.",
        stacklevel=2,
    )
    return interpolate(input, size, scale_factor, mode="nearest")


# Substitute the shared reproducibility note into the docstring placeholder;
# skipped when the docstring is empty or missing.
if upsample_nearest.__doc__:
    upsample_nearest.__doc__ = upsample_nearest.__doc__.format(**reproducibility_notes)
# Typed overloads: every combination of scalar/list ``size`` with scalar/list
# ``scale_factor``. The bodies are intentionally empty.
@_overload
def upsample_bilinear(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[float] = None,
) -> Tensor:
    pass


@_overload
def upsample_bilinear(  # noqa: F811
    input: Tensor,
    size: Optional[list[int]] = None,
    scale_factor: Optional[float] = None,
) -> Tensor:
    pass


@_overload
def upsample_bilinear(  # noqa: F811
    input: Tensor,
    size: Optional[int] = None,
    scale_factor: Optional[list[float]] = None,
) -> Tensor:
    pass


@_overload
def upsample_bilinear(  # noqa: F811
    input: Tensor,
    size: Optional[list[int]] = None,
    scale_factor: Optional[list[float]] = None,
) -> Tensor:
    pass


def upsample_bilinear(input, size=None, scale_factor=None):  # noqa: F811
    r"""Upsample the input with bilinear interpolation (deprecated).

    .. warning::
        This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
        This is equivalent with
        ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``.

    Expected inputs are spatial (4 dimensional). Use `upsample_trilinear` for
    volumetric (5 dimensional) inputs.

    Args:
        input (Tensor): input
        size (int or Tuple[int, int]): output spatial size.
        scale_factor (int or Tuple[int, int]): multiplier for spatial size

    Note:
        {backward_reproducibility_note}
    """
    deprecation_msg = (
        "`nn.functional.upsample_bilinear` is deprecated. "
        "Use `nn.functional.interpolate` instead."
    )
    # DeprecationWarning is ignored by default
    warnings.warn(deprecation_msg, stacklevel=2)
    # This legacy entry point always interpolates with align_corners=True.
    return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True)


# Substitute the shared reproducibility note into the docstring placeholder;
# skipped when the docstring is empty or missing.
if upsample_bilinear.__doc__:
    upsample_bilinear.__doc__ = upsample_bilinear.__doc__.format(
        **reproducibility_notes
    )
# Integer code for each interpolation mode name.
GRID_SAMPLE_INTERPOLATION_MODES = dict(bilinear=0, nearest=1, bicubic=2)

# Integer code for each padding mode name.
GRID_SAMPLE_PADDING_MODES = dict(zeros=0, border=1, reflection=2)
def grid_sample(
    input: Tensor,
    grid: Tensor,
    mode: str = "bilinear",
    padding_mode: str = "zeros",
    align_corners: Optional[bool] = None,
) -> Tensor:
    r"""Sample :attr:`input` at the locations given by a flow-field :attr:`grid`.

    Given an :attr:`input` and a flow-field :attr:`grid`, computes the
    ``output`` using :attr:`input` values and pixel locations from :attr:`grid`.

    Currently, only spatial (4-D) and volumetric (5-D) :attr:`input` are
    supported.

    In the spatial (4-D) case, for :attr:`input` with shape
    :math:`(N, C, H_\text{in}, W_\text{in})` and :attr:`grid` with shape
    :math:`(N, H_\text{out}, W_\text{out}, 2)`, the output will have shape
    :math:`(N, C, H_\text{out}, W_\text{out})`.

    For each output location ``output[n, :, h, w]``, the size-2 vector
    ``grid[n, h, w]`` specifies :attr:`input` pixel locations ``x`` and ``y``,
    which are used to interpolate the output value ``output[n, :, h, w]``.
    In the case of 5D inputs, ``grid[n, d, h, w]`` specifies the
    ``x``, ``y``, ``z`` pixel locations for interpolating
    ``output[n, :, d, h, w]``. :attr:`mode` argument specifies ``nearest`` or
    ``bilinear`` interpolation method to sample the input pixels.

    :attr:`grid` specifies the sampling pixel locations normalized by the
    :attr:`input` spatial dimensions. Therefore, it should have most values in
    the range of ``[-1, 1]``. For example, values ``x = -1, y = -1`` is the
    left-top pixel of :attr:`input`, and values  ``x = 1, y = 1`` is the
    right-bottom pixel of :attr:`input`.

    If :attr:`grid` has values outside the range of ``[-1, 1]``, the corresponding
    outputs are handled as defined by :attr:`padding_mode`. Options are

        * ``padding_mode="zeros"``: use ``0`` for out-of-bound grid locations,
        * ``padding_mode="border"``: use border values for out-of-bound grid locations,
        * ``padding_mode="reflection"``: use values at locations reflected by
          the border for out-of-bound grid locations. For location far away
          from the border, it will keep being reflected until becoming in bound,
          e.g., (normalized) pixel location ``x = -3.5`` reflects by border ``-1``
          and becomes ``x' = 1.5``, then reflects by border ``1`` and becomes
          ``x'' = -0.5``.

    Note:
        This function is often used in conjunction with :func:`affine_grid`
        to build `Spatial Transformer Networks`_ .

    Note:
        When using the CUDA backend, this operation may induce nondeterministic
        behaviour in its backward pass that is not easily switched off.
        Please see the notes on :doc:`/notes/randomness` for background.

    Note:
        NaN values in :attr:`grid` would be interpreted as ``-1``.

    Args:
        input (Tensor): input of shape :math:`(N, C, H_\text{in}, W_\text{in})` (4-D case)
                        or :math:`(N, C, D_\text{in}, H_\text{in}, W_\text{in})` (5-D case)
        grid (Tensor): flow-field of shape :math:`(N, H_\text{out}, W_\text{out}, 2)` (4-D case)
                       or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case)
        mode (str): interpolation mode to calculate output values
            ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'``
            Note: ``mode='bicubic'`` supports only 4-D input.
            When ``mode='bilinear'`` and the input is 5-D, the interpolation mode
            used internally will actually be trilinear. However, when the input is 4-D,
            the interpolation mode will legitimately be bilinear.
        padding_mode (str): padding mode for outside grid values
            ``'zeros'`` | ``'border'`` | ``'reflection'``. Default: ``'zeros'``
        align_corners (bool, optional): Geometrically, we consider the pixels of the
            input  as squares rather than points.
            If set to ``True``, the extrema (``-1`` and ``1``) are considered as referring
            to the center points of the input's corner pixels. If set to ``False``, they
            are instead considered as referring to the corner points of the input's corner
            pixels, making the sampling more resolution agnostic.
            This option parallels the ``align_corners`` option in
            :func:`interpolate`, and so whichever option is used here
            should also be used there to resize the input image before grid sampling.
            Default: ``False``

    Returns:
        output (Tensor): output Tensor

    .. _`Spatial Transformer Networks`:
        https://arxiv.org/abs/1506.02025

    .. warning::
        When ``align_corners = True``, the grid positions depend on the pixel
        size relative to the input image size, and so the locations sampled by
        :func:`grid_sample` will differ for the same input given at different
        resolutions (that is, after being upsampled or downsampled).
        The default behavior up to version 1.2.0 was ``align_corners = True``.
        Since then, the default behavior has been changed to ``align_corners = False``,
        in order to bring it in line with the default for :func:`interpolate`.

    .. note::
        ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`.
        The constant :math:`\alpha` might be different from packages to packages.
        For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively.
        This algorithm may "overshoot" the range of values it's interpolating.
        For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255].
        Clamp the results with :func:`torch.clamp` to ensure they are within the valid range.

    .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation
    .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51
    .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908
    """
    if has_torch_function_variadic(input, grid):
        return handle_torch_function(
            grid_sample,
            (input, grid),
            input,
            grid,
            mode=mode,
            padding_mode=padding_mode,
            align_corners=align_corners,
        )

    # Translate ``mode`` to its integer code, rejecting unknown names. This
    # is checked before ``padding_mode`` so a bad mode is reported first.
    if mode == "bilinear":
        mode_enum = 0
    elif mode == "nearest":
        mode_enum = 1
    elif mode == "bicubic":
        mode_enum = 2
    else:
        raise ValueError(
            f"nn.functional.grid_sample(): expected mode to be 'bilinear', 'nearest' or 'bicubic', but got: '{mode}'"
        )

    # Same translation for ``padding_mode``.
    if padding_mode == "zeros":
        padding_mode_enum = 0
    elif padding_mode == "border":
        padding_mode_enum = 1
    elif padding_mode == "reflection":
        padding_mode_enum = 2
    else:
        raise ValueError(
            "nn.functional.grid_sample(): expected padding_mode "
            "to be 'zeros', 'border', or 'reflection', "
            f"but got: '{padding_mode}'"
        )

    # An unspecified align_corners falls back to False, with a warning
    # because the default differs from older releases.
    if align_corners is None:
        warnings.warn(
            "Default grid_sample and affine_grid behavior has changed "
            "to align_corners=False since 1.3.0. Please specify "
            "align_corners=True if the old behavior is desired. "
            "See the documentation of grid_sample for details."
        )
        align_corners = False

    return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
def affine_grid(
    theta: Tensor,
    size: list[int],
    align_corners: Optional[bool] = None,
) -> Tensor:
    r"""Generate 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`.

    .. note::
        This function is often used in conjunction with :func:`grid_sample`
        to build `Spatial Transformer Networks`_ .

    Args:
        theta (Tensor): input batch of affine matrices with shape
            (:math:`N \times 2 \times 3`) for 2D or
            (:math:`N \times 3 \times 4`) for 3D
        size (torch.Size): the target output image size.
            (:math:`N \times C \times H \times W` for 2D or
            :math:`N \times C \times D \times H \times W` for 3D)
            Example: torch.Size((32, 3, 24, 24))
        align_corners (bool, optional): if ``True``, consider ``-1`` and ``1``
            to refer to the centers of the corner pixels rather than the image corners.
            Refer to :func:`grid_sample` for a more complete description.
            A grid generated by :func:`affine_grid` should be passed to :func:`grid_sample`
            with the same setting for this option.
            Default: ``False``

    Returns:
        output (Tensor): output Tensor of size (:math:`N \times H \times W \times 2`)

    Raises:
        ValueError: if :attr:`theta` is not a floating point tensor, if its shape
            does not match the dimensionality implied by :attr:`size`, or if any
            entry of :attr:`size` is zero or negative.
        NotImplementedError: if :attr:`size` is neither 4D nor 5D.

    .. _`Spatial Transformer Networks`:
        https://arxiv.org/abs/1506.02025

    .. warning::
        When ``align_corners = True``, the grid positions depend on the pixel
        size relative to the input image size, and so the locations sampled by
        :func:`grid_sample` will differ for the same input given at different
        resolutions (that is, after being upsampled or downsampled).
        The default behavior up to version 1.2.0 was ``align_corners = True``.
        Since then, the default behavior has been changed to ``align_corners = False``,
        in order to bring it in line with the default for :func:`interpolate`.

    .. warning::
        When ``align_corners = True``, 2D affine transforms on 1D data and
        3D affine transforms on 2D data (that is, when one of the spatial
        dimensions has unit size) are ill-defined, and not an intended use case.
        This is not a problem when ``align_corners = False``.
        Up to version 1.2.0, all grid points along a unit dimension were
        considered arbitrarily to be at ``-1``.
        From version 1.3.0, under ``align_corners = True`` all grid points
        along a unit dimension are considered to be at ``0``
        (the center of the input image).
    """
    if has_torch_function_unary(theta):
        return handle_torch_function(
            affine_grid, (theta,), theta, size, align_corners=align_corners
        )
    if align_corners is None:
        warnings.warn(
            "Default grid_sample and affine_grid behavior has changed "
            "to align_corners=False since 1.3.0. Please specify "
            "align_corners=True if the old behavior is desired. "
            "See the documentation of grid_sample for details."
        )
        align_corners = False

    # enforce floating point dtype on theta
    if not theta.is_floating_point():
        raise ValueError(
            f"Expected theta to have floating point type, but got {theta.dtype}"
        )
    # check that shapes and sizes match
    if len(size) == 4:
        if theta.dim() != 3 or theta.shape[-2] != 2 or theta.shape[-1] != 3:
            raise ValueError(
                f"Expected a batch of 2D affine matrices of shape Nx2x3 for size {size}. Got {theta.shape}."
            )
        spatial_size = size[-2:]  # spatial dimension sizes
    elif len(size) == 5:
        if theta.dim() != 3 or theta.shape[-2] != 3 or theta.shape[-1] != 4:
            raise ValueError(
                f"Expected a batch of 3D affine matrices of shape Nx3x4 for size {size}. Got {theta.shape}."
            )
        spatial_size = size[-3:]  # spatial dimension sizes
    else:
        raise NotImplementedError(
            "affine_grid only supports 4D and 5D sizes, "
            "for 2D and 3D affine transforms, respectively. "
            f"Got size {size}."
        )

    # Reject non-positive sizes unconditionally. This check must not be chained
    # to the unit-size warning below: otherwise a zero/negative batch or channel
    # entry slips through whenever align_corners=True and a spatial dim is 1.
    if min(size) <= 0:
        raise ValueError(f"Expected non-zero, positive output size. Got {size}")

    # check for empty span
    if align_corners and min(spatial_size) == 1:
        warnings.warn(
            "Since version 1.3.0, affine_grid behavior has changed "
            "for unit-size grids when align_corners=True. "
            "This is not an intended use case of affine_grid. "
            "See the documentation of affine_grid for details."
        )

    return torch.affine_grid_generator(theta, size, align_corners)
def pad(
    input: Tensor,
    pad: list[int],
    mode: str = "constant",
    value: Optional[float] = None,
) -> Tensor:
    r"""
    pad(input, pad, mode="constant", value=None) -> Tensor

    Pads tensor.

    Padding size:
        The padding size by which to pad some dimensions of :attr:`input`
        are described starting from the last dimension and moving forward.
        :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions
        of ``input`` will be padded.
        For example, to pad only the last dimension of the input tensor, then
        :attr:`pad` has the form
        :math:`(\text{padding\_left}, \text{padding\_right})`;
        to pad the last 2 dimensions of the input tensor, then use
        :math:`(\text{padding\_left}, \text{padding\_right},`
        :math:`\text{padding\_top}, \text{padding\_bottom})`;
        to pad the last 3 dimensions, use
        :math:`(\text{padding\_left}, \text{padding\_right},`
        :math:`\text{padding\_top}, \text{padding\_bottom}`
        :math:`\text{padding\_front}, \text{padding\_back})`.

    Padding mode:
        See :class:`torch.nn.CircularPad2d`, :class:`torch.nn.ConstantPad2d`,
        :class:`torch.nn.ReflectionPad2d`, and :class:`torch.nn.ReplicationPad2d`
        for concrete examples on how each of the padding modes works. Constant
        padding is implemented for arbitrary dimensions. Circular, replicate and
        reflection padding are implemented for padding the last 3 dimensions of a
        4D or 5D input tensor, the last 2 dimensions of a 3D or 4D input tensor,
        or the last dimension of a 2D or 3D input tensor.

    Note:
        When using the CUDA backend, this operation may induce nondeterministic
        behaviour in its backward pass that is not easily switched off.
        Please see the notes on :doc:`/notes/randomness` for background.

    Args:
        input (Tensor): N-dimensional tensor
        pad (tuple): m-elements tuple, where
            :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even.
        mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
            Default: ``'constant'``
        value: fill value for ``'constant'`` padding. Default: ``0``

    Examples::

        >>> t4d = torch.empty(3, 3, 4, 2)
        >>> p1d = (1, 1) # pad last dim by 1 on each side
        >>> out = F.pad(t4d, p1d, "constant", 0)  # effectively zero padding
        >>> print(out.size())
        torch.Size([3, 3, 4, 4])
        >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2)
        >>> out = F.pad(t4d, p2d, "constant", 0)
        >>> print(out.size())
        torch.Size([3, 3, 8, 4])
        >>> t4d = torch.empty(3, 3, 4, 2)
        >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3)
        >>> out = F.pad(t4d, p3d, "constant", 0)
        >>> print(out.size())
        torch.Size([3, 9, 7, 3])
    """
    # Hand off to handle_torch_function when the override check on ``input`` fires.
    if has_torch_function_unary(input):
        return handle_torch_function(
            torch.nn.functional.pad, (input,), input, pad, mode=mode, value=value
        )
    # The deterministic special case is only considered outside of scripting;
    # the guard is kept as its own statement (see the importlib note below).
    if not torch.jit.is_scripting():
        if torch.are_deterministic_algorithms_enabled() and (
            input.is_cuda or input.is_xpu
        ):
            if mode == "replicate":
                # Use slow decomp whose backward will be in terms of index_put.
                # importlib is required because the import cannot be top level
                # (cycle) and cannot be nested (TS doesn't support)
                return importlib.import_module(
                    "torch._decomp.decompositions"
                )._replication_pad(input, pad)
    # Every other combination is forwarded unchanged to the native pad op.
    return torch._C._nn.pad(input, pad, mode, value)


# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798
# Until then, pin the reported module of ``pad`` explicitly.
pad.__module__ = "torch.nn.functional"
- # distance
# Bind the result of ``_add_docstr`` (called with ``torch.pairwise_distance``
# and the documentation text below) to this module's ``pairwise_distance`` name.
pairwise_distance = _add_docstr(
    torch.pairwise_distance,
    r"""
pairwise_distance(x1, x2, p=2.0, eps=1e-6, keepdim=False) -> Tensor

See :class:`torch.nn.PairwiseDistance` for details
""",
)
# Bind the result of ``_add_docstr`` (called with ``torch.pdist`` and the
# documentation text below) to this module's ``pdist`` name.
pdist = _add_docstr(
    torch.pdist,
    r"""
pdist(input, p=2) -> Tensor

Computes the p-norm distance between every pair of row vectors in the input.
This is identical to the upper triangular portion, excluding the diagonal, of
`torch.norm(input[:, None] - input, dim=2, p=p)`. This function will be faster
if the rows are contiguous.

If input has shape :math:`N \times M` then the output will have shape
:math:`\frac{1}{2} N (N - 1)`.

This function is equivalent to ``scipy.spatial.distance.pdist(input,
'minkowski', p=p)`` if :math:`p \in (0, \infty)`. When :math:`p = 0` it is
equivalent to ``scipy.spatial.distance.pdist(input, 'hamming') * M``.
When :math:`p = \infty`, the closest scipy function is
``scipy.spatial.distance.pdist(xn, lambda x, y: np.abs(x - y).max())``.

Args:
    input: input tensor of shape :math:`N \times M`.
    p: p value for the p-norm distance to calculate between each vector pair
        :math:`\in [0, \infty]`.
""",
)
# Bind the result of ``_add_docstr`` (called with ``torch.cosine_similarity``
# and the documentation text below) to this module's ``cosine_similarity`` name.
cosine_similarity = _add_docstr(
    torch.cosine_similarity,
    r"""
cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor

Returns cosine similarity between ``x1`` and ``x2``, computed along dim. ``x1`` and ``x2`` must be broadcastable
to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is
squeezed (see :func:`torch.squeeze`), resulting in the
output tensor having 1 fewer dimension.

.. math ::
    \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2, \epsilon) \cdot \max(\Vert x_2 \Vert _2, \epsilon)}

Supports :ref:`type promotion <type-promotion-doc>`.

Args:
    x1 (Tensor): First input.
    x2 (Tensor): Second input.
    dim (int, optional): Dimension along which cosine similarity is computed. Default: 1
    eps (float, optional): Small value to avoid division by zero.
        Default: 1e-8

Example::

    >>> input1 = torch.randn(100, 128)
    >>> input2 = torch.randn(100, 128)
    >>> output = F.cosine_similarity(input1, input2)
    >>> print(output)
""",
)
# Bind the result of ``_add_docstr`` (called with ``torch._C._nn.one_hot`` and
# the documentation text below) to this module's ``one_hot`` name.
one_hot = _add_docstr(
    torch._C._nn.one_hot,
    r"""
one_hot(tensor, num_classes=-1) -> LongTensor

Takes LongTensor with index values of shape ``(*)`` and returns a tensor
of shape ``(*, num_classes)`` that have zeros everywhere except where the
index of last dimension matches the corresponding value of the input tensor,
in which case it will be 1.

See also `One-hot on Wikipedia`_ .

.. _One-hot on Wikipedia:
    https://en.wikipedia.org/wiki/One-hot

Arguments:
    tensor (LongTensor): class values of any shape.
    num_classes (int, optional):  Total number of classes. If set to -1, the number
        of classes will be inferred as one greater than the largest class
        value in the input tensor. Default: -1

Returns:
    LongTensor that has one more dimension with 1 values at the
    index of last dimension indicated by the input, and 0 everywhere
    else.

Examples:
    >>> F.one_hot(torch.arange(0, 5) % 3)
    tensor([[1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 0],
            [0, 1, 0]])
    >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
    tensor([[1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0]])
    >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3)
    tensor([[[1, 0, 0],
             [0, 1, 0]],
            [[0, 0, 1],
             [1, 0, 0]],
            [[0, 1, 0],
             [0, 0, 1]]])
""",
)
def triplet_margin_loss(
    anchor: Tensor,
    positive: Tensor,
    negative: Tensor,
    margin: float = 1.0,
    p: float = 2,
    eps: float = 1e-6,
    swap: bool = False,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> Tensor:
    r"""Compute the triplet loss between given input tensors and a margin greater than 0.

    The reduction is resolved first (through the legacy ``size_average`` /
    ``reduce`` pair when either is given, otherwise from ``reduction``), then
    ``margin`` is validated, and finally the work is delegated to
    ``torch.triplet_margin_loss``.

    See :class:`~torch.nn.TripletMarginLoss` for details.

    Raises:
        ValueError: if ``margin`` is not strictly positive.
    """
    if has_torch_function_variadic(anchor, positive, negative):
        return handle_torch_function(
            triplet_margin_loss,
            (anchor, positive, negative),
            anchor,
            positive,
            negative,
            margin=margin,
            p=p,
            eps=eps,
            swap=swap,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )

    # Only fall back to the string ``reduction`` when neither legacy flag is set.
    if size_average is None and reduce is None:
        reduction_code = _Reduction.get_enum(reduction)
    else:
        reduction_code = _Reduction.legacy_get_enum(size_average, reduce)

    if margin <= 0:
        raise ValueError(f"margin must be greater than 0, got {margin}")

    return torch.triplet_margin_loss(
        anchor, positive, negative, margin, p, eps, swap, reduction_code
    )
def triplet_margin_with_distance_loss(
    anchor: Tensor,
    positive: Tensor,
    negative: Tensor,
    *,
    distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
    margin: float = 1.0,
    swap: bool = False,
    reduction: str = "mean",
) -> Tensor:
    r"""Compute the triplet margin loss for input tensors using a custom distance function.

    When ``distance_function`` is ``None``, ``torch.pairwise_distance`` is used.
    The per-sample loss is ``clamp_min(margin + d(a, p) - d(a, n), 0)``, reduced
    according to ``reduction``.

    See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details.

    Raises:
        NotImplementedError: when called under JIT scripting.
        ValueError: if ``reduction`` is not ``"mean"``, ``"sum"`` or ``"none"``,
            or if ``margin`` is not strictly positive.
        RuntimeError: if the three inputs do not share the same number of dimensions.
    """
    if torch.jit.is_scripting():
        raise NotImplementedError(
            "F.triplet_margin_with_distance_loss does not support JIT scripting: "
            "functions requiring Callables cannot be scripted."
        )

    if has_torch_function_variadic(anchor, positive, negative):
        return handle_torch_function(
            triplet_margin_with_distance_loss,
            (anchor, positive, negative),
            anchor,
            positive,
            negative,
            distance_function=distance_function,
            margin=margin,
            swap=swap,
            reduction=reduction,
        )

    # --- argument validation: reduction mode, margin, then ranks ---
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"{reduction} is not a valid value for reduction")
    if margin <= 0:
        raise ValueError(f"margin must be greater than 0, got {margin}")

    a_dim, p_dim, n_dim = anchor.ndim, positive.ndim, negative.ndim
    if a_dim != p_dim or p_dim != n_dim:
        raise RuntimeError(
            f"The anchor, positive, and negative tensors are expected to have "
            f"the same number of dimensions, but got: anchor {a_dim}D, "
            f"positive {p_dim}D, and negative {n_dim}D inputs"
        )

    # --- loss computation ---
    metric = torch.pairwise_distance if distance_function is None else distance_function
    dist_pos = metric(anchor, positive)
    dist_neg = metric(anchor, negative)
    if swap:
        # Distance swap from "Learning shallow convolutional feature descriptors
        # with triplet losses" (V. Balntas, E. Riba et al.): use the smaller of
        # the anchor-negative and positive-negative distances.
        dist_neg = torch.minimum(dist_neg, metric(positive, negative))
    hinge = torch.clamp_min(margin + dist_pos - dist_neg, 0)

    # --- reduction (already validated above) ---
    if reduction == "none":
        return hinge
    if reduction == "sum":
        return torch.sum(hinge)
    return torch.mean(hinge)
def normalize(
    input: Tensor,
    p: float = 2.0,
    dim: int = 1,
    eps: float = 1e-12,
    out: Optional[Tensor] = None,
) -> Tensor:
    r"""Perform :math:`L_p` normalization of inputs over specified dimension.

    For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
    :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`dim` is transformed as

    .. math::
        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.

    With the default arguments it uses the Euclidean norm over vectors along dimension :math:`1` for normalization.

    Args:
        input: input tensor of any shape
        p (float): the exponent value in the norm formulation. Default: 2
        dim (int or tuple of ints): the dimension to reduce. Default: 1
        eps (float): small value to avoid division by zero. Default: 1e-12
        out (Tensor, optional): the output tensor. If :attr:`out` is used, this
                                operation won't be differentiable.
    """
    if has_torch_function_variadic(input, out):
        return handle_torch_function(
            normalize, (input, out), input, p=p, dim=dim, eps=eps, out=out
        )

    # The norm keeps the reduced dimension so it can be expanded back to input's shape.
    norm = input.norm(p, dim, keepdim=True)
    if out is not None:
        # ``out`` variant: clamp the norm in place and write the quotient into ``out``.
        divisor = norm.clamp_min_(eps).expand_as(input)
        return torch.div(input, divisor, out=out)
    divisor = norm.clamp_min(eps).expand_as(input)
    return input / divisor
def assert_int_or_pair(arg: list[int], arg_name: str, message: str) -> None:
    """Check that ``arg`` is either a single int or a sequence of length 2.

    Args:
        arg: the value to validate (an ``int`` or a sized container).
        arg_name: name substituted into ``message`` when validation fails.
        message: ``str.format`` template taking ``arg_name`` as its only
            positional field.

    Raises:
        AssertionError: with ``message.format(arg_name)`` if ``arg`` is neither
            an int nor of length 2.
    """
    # Raise explicitly instead of using ``assert`` so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if not (isinstance(arg, int) or len(arg) == 2):
        raise AssertionError(message.format(arg_name))
def unfold(
    input: Tensor,
    kernel_size: BroadcastingList2[int],
    dilation: BroadcastingList2[int] = 1,
    padding: BroadcastingList2[int] = 0,
    stride: BroadcastingList2[int] = 1,
) -> Tensor:
    r"""Extract sliding local blocks from a batched input tensor.

    Each of ``kernel_size``, ``dilation``, ``padding`` and ``stride`` is passed
    through ``_pair`` before being handed to ``torch._C._nn.im2col``.

    .. warning::
        Currently, only 4-D input tensors (batched image-like tensors) are
        supported.

    .. warning::

        More than one element of the unfolded tensor may refer to a single
        memory location. As a result, in-place operations (especially ones that
        are vectorized) may result in incorrect behavior. If you need to write
        to the tensor, please clone it first.


    See :class:`torch.nn.Unfold` for details
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            unfold,
            (input,),
            input,
            kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride,
        )

    kernel_hw = _pair(kernel_size)
    dilation_hw = _pair(dilation)
    padding_hw = _pair(padding)
    stride_hw = _pair(stride)
    return torch._C._nn.im2col(input, kernel_hw, dilation_hw, padding_hw, stride_hw)
def fold(
    input: Tensor,
    output_size: BroadcastingList2[int],
    kernel_size: BroadcastingList2[int],
    dilation: BroadcastingList2[int] = 1,
    padding: BroadcastingList2[int] = 0,
    stride: BroadcastingList2[int] = 1,
) -> Tensor:
    r"""Combine an array of sliding local blocks into a large containing tensor.

    Each of ``output_size``, ``kernel_size``, ``dilation``, ``padding`` and
    ``stride`` is passed through ``_pair`` before being handed to
    ``torch._C._nn.col2im``.

    .. warning::
        Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported.

    See :class:`torch.nn.Fold` for details
    """
    if has_torch_function_unary(input):
        return handle_torch_function(
            fold,
            (input,),
            input,
            output_size,
            kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride,
        )

    output_hw = _pair(output_size)
    kernel_hw = _pair(kernel_size)
    dilation_hw = _pair(dilation)
    padding_hw = _pair(padding)
    stride_hw = _pair(stride)
    return torch._C._nn.col2im(
        input, output_hw, kernel_hw, dilation_hw, padding_hw, stride_hw
    )
- #
- # multihead attention
- #
- def _in_projection_packed(
- q: Tensor,
- k: Tensor,
- v: Tensor,
- w: Tensor,
- b: Optional[Tensor] = None,
- ) -> list[Tensor]:
- r"""Perform the in-projection step of the attention operation, using packed weights.
- Output is a triple containing projection tensors for query, key and value.
- Args:
- q, k, v: query, key and value tensors to be projected. For self-attention,
- these are typically the same tensor; for encoder-decoder attention,
- k and v are typically the same tensor. (We take advantage of these
- identities for performance if they are present.) Regardless, q, k and v
- must share a common embedding dimension; otherwise their shapes may vary.
- w: projection weights for q, k and v, packed into a single tensor. Weights
- are packed along dimension 0, in q, k, v order.
- b: optional projection biases for q, k and v, packed into a single tensor
- in q, k, v order.
- Shape:
- Inputs:
- - q: :math:`(..., E)` where E is the embedding dimension
- - k: :math:`(..., E)` where E is the embedding dimension
- - v: :math:`(..., E)` where E is the embedding dimension
- - w: :math:`(E * 3, E)` where E is the embedding dimension
- - b: :math:`E * 3` where E is the embedding dimension
- Output:
- - in output list :math:`[q', k', v']`, each output tensor will have the
- same shape as the corresponding input tensor.
- """
- E = q.size(-1)
- if k is v:
- if q is k:
- # self-attention
- proj = linear(q, w, b)
- # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
- proj = (
- proj.unflatten(-1, (3, E))
- .unsqueeze(0)
- .transpose(0, -2)
- .squeeze(-2)
- .contiguous()
- )
- return proj[0], proj[1], proj[2]
- else:
- # encoder-decoder attention
- w_q, w_kv = w.split([E, E * 2])
- if b is None:
- b_q = b_kv = None
- else:
- b_q, b_kv = b.split([E, E * 2])
- q_proj = linear(q, w_q, b_q)
- kv_proj = linear(k, w_kv, b_kv)
- # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
- kv_proj = (
- kv_proj.unflatten(-1, (2, E))
- .unsqueeze(0)
- .transpose(0, -2)
- .squeeze(-2)
- .contiguous()
- )
- return (q_proj, kv_proj[0], kv_proj[1])
- else:
- w_q, w_k, w_v = w.chunk(3)
- if b is None:
- b_q = b_k = b_v = None
- else:
- b_q, b_k, b_v = b.chunk(3)
- return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
- def _in_projection(
- q: Tensor,
- k: Tensor,
- v: Tensor,
- w_q: Tensor,
- w_k: Tensor,
- w_v: Tensor,
- b_q: Optional[Tensor] = None,
- b_k: Optional[Tensor] = None,
- b_v: Optional[Tensor] = None,
- ) -> tuple[Tensor, Tensor, Tensor]:
- r"""Perform the in-projection step of the attention operation.
- This is simply a triple of linear projections,
- with shape constraints on the weights which
- ensure embedding dimension uniformity in the projected outputs.
- Output is a triple containing projection tensors for query, key and value.
- Args:
- q, k, v: query, key and value tensors to be projected.
- w_q, w_k, w_v: weights for q, k and v, respectively.
- b_q, b_k, b_v: optional biases for q, k and v, respectively.
- Shape:
- Inputs:
- - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any
- number of leading dimensions.
- - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any
- number of leading dimensions.
- - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any
- number of leading dimensions.
- - w_q: :math:`(Eq, Eq)`
- - w_k: :math:`(Eq, Ek)`
- - w_v: :math:`(Eq, Ev)`
- - b_q: :math:`(Eq)`
- - b_k: :math:`(Eq)`
- - b_v: :math:`(Eq)`
- Output: in output triple :math:`(q', k', v')`,
- - q': :math:`[Qdims..., Eq]`
- - k': :math:`[Kdims..., Eq]`
- - v': :math:`[Vdims..., Eq]`
- """
- Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
- assert w_q.shape == (
- Eq,
- Eq,
- ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
- assert w_k.shape == (
- Eq,
- Ek,
- ), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
- assert w_v.shape == (
- Eq,
- Ev,
- ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
- assert b_q is None or b_q.shape == (Eq,), (
- f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
- )
- assert b_k is None or b_k.shape == (Eq,), (
- f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
- )
- assert b_v is None or b_v.shape == (Eq,), (
- f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
- )
- return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
# Bind the result of ``_add_docstr`` (called with
# ``torch._C._nn.scaled_dot_product_attention`` and the text below) to this
# module's ``scaled_dot_product_attention`` name.
# The documentation is the concatenation of two raw strings: only the first is
# passed through ``str.format`` (to fill in ``{cudnn_reproducibility_note}``).
# The second contains literal braces (e.g. ``\frac{1}{\sqrt{E}}``) and is
# therefore appended unformatted.
scaled_dot_product_attention = _add_docstr(
    torch._C._nn.scaled_dot_product_attention,
    r"""scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> Tensor:

    Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed,
    and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be
    specified as a keyword argument.

    .. code-block:: python

        # Efficient implementation equivalent to the following:
        def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
            L, S = query.size(-2), key.size(-2)
            scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
            attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
            if is_causal:
                assert attn_mask is None
                temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
                attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

            if attn_mask is not None:
                if attn_mask.dtype == torch.bool:
                    attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
                else:
                    attn_bias = attn_mask + attn_bias

            if enable_gqa:
                key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
                value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

            attn_weight = query @ key.transpose(-2, -1) * scale_factor
            attn_weight += attn_bias
            attn_weight = torch.softmax(attn_weight, dim=-1)
            attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
            return attn_weight @ value

    .. warning::
        This function is beta and subject to change.

    .. warning::
        This function always applies dropout according to the specified ``dropout_p`` argument.
        To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module
        that makes the function call is not in training mode.

        For example:

        .. code-block:: python

            class MyModel(nn.Module):
                def __init__(self, p=0.5):
                    super().__init__()
                    self.p = p

                def forward(self, ...):
                    return F.scaled_dot_product_attention(...,
                        dropout_p=(self.p if self.training else 0.0))

    Note:

        There are currently three supported implementations of scaled dot product attention:

            - `FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning`_
            - `Memory-Efficient Attention`_
            - A PyTorch implementation defined in C++ matching the above formulation

        The function may call optimized kernels for improved performance when using the CUDA backend.
        For all other backends, the PyTorch implementation will be used.

        All implementations are enabled by default. Scaled dot product attention attempts to automatically select the
        most optimal implementation based on the inputs. In order to provide more fine-grained control over what implementation
        is used, the following functions are provided for enabling and disabling implementations.
        The context manager is the preferred mechanism:

            - :func:`torch.nn.attention.sdpa_kernel`: A context manager used to enable or disable any of the implementations.
            - :func:`torch.backends.cuda.enable_flash_sdp`: Globally enables or disables FlashAttention.
            - :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Globally enables or disables  Memory-Efficient Attention.
            - :func:`torch.backends.cuda.enable_math_sdp`: Globally enables or disables the PyTorch C++ implementation.

        Each of the fused kernels has specific input limitations. If the user requires the use of a specific fused implementation,
        disable the PyTorch C++ implementation using :func:`torch.nn.attention.sdpa_kernel`.
        In the event that a fused implementation is not available, a warning will be raised with the
        reasons why the fused implementation cannot run.

        Due to the nature of fusing floating point operations, the output of this function may be different
        depending on what backend kernel is chosen.
        The c++ implementation supports torch.float64 and can be used when higher precision is required.
        For math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16.
        For more information please see :doc:`/notes/numerical_accuracy`

        Grouped Query Attention (GQA) is an experimental feature. It currently works only for Flash_attention
        and math kernel on CUDA tensor, and does not support Nested tensor.
        Constraints for GQA:

            - number_of_heads_query % number_of_heads_key_value == 0 and,
            - number_of_heads_key == number_of_heads_value

    Note:

        {cudnn_reproducibility_note}
    """.format(**reproducibility_notes)
    + r"""
    Args:
        query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
        key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`.
        value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`.
        attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights,
            which is :math:`(N,..., L, S)`. Two types of masks are supported.
            A boolean mask where a value of True indicates that the element *should* take part in attention.
            A float mask of the same type as query, key, value that is added to the attention score.
        dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
        is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a
            square matrix. The attention masking has the form of the upper left causal bias due to the alignment
            (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix.
            An error is thrown if both attn_mask and is_causal are set.
        scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set
            to :math:`\frac{1}{\sqrt{E}}`.
        enable_gqa (bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False.

    Returns:
        output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`.

    Shape legend:
        - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
        - :math:`S: \text{Source sequence length}`
        - :math:`L: \text{Target sequence length}`
        - :math:`E: \text{Embedding dimension of the query and key}`
        - :math:`Ev: \text{Embedding dimension of the value}`
        - :math:`Hq: \text{Number of heads of query}`
        - :math:`H: \text{Number of heads of key and value}`

    Examples:

        >>> # Optionally use the context manager to ensure one of the fused kernels is run
        >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
        >>>     F.scaled_dot_product_attention(query,key,value)

        >>> # Sample for GQA for llama3
        >>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
        >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        >>> with sdpa_kernel(backends=[SDPBackend.MATH]):
        >>>     F.scaled_dot_product_attention(query,key,value,enable_gqa=True)

    .. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning:
        https://arxiv.org/abs/2307.08691
    .. _Memory-Efficient Attention:
        https://github.com/facebookresearch/xformers
    .. _Grouped-Query Attention:
        https://arxiv.org/pdf/2305.13245
    """,
)
- def _mha_shape_check(
- query: Tensor,
- key: Tensor,
- value: Tensor,
- key_padding_mask: Optional[Tensor],
- attn_mask: Optional[Tensor],
- num_heads: int,
- ):
- # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
- # and returns if the input is batched or not.
- # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
- # Shape check.
- if query.dim() == 3:
- # Batched Inputs
- is_batched = True
- assert key.dim() == 3 and value.dim() == 3, (
- "For batched (3-D) `query`, expected `key` and `value` to be 3-D"
- f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
- )
- if key_padding_mask is not None:
- assert key_padding_mask.dim() == 2, (
- "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
- f" but found {key_padding_mask.dim()}-D tensor instead"
- )
- if attn_mask is not None:
- assert attn_mask.dim() in (2, 3), (
- "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
- f" but found {attn_mask.dim()}-D tensor instead"
- )
- elif query.dim() == 2:
- # Unbatched Inputs
- is_batched = False
- assert key.dim() == 2 and value.dim() == 2, (
- "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
- f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
- )
- if key_padding_mask is not None:
- assert key_padding_mask.dim() == 1, (
- "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
- f" but found {key_padding_mask.dim()}-D tensor instead"
- )
- if attn_mask is not None:
- assert attn_mask.dim() in (2, 3), (
- "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
- f" but found {attn_mask.dim()}-D tensor instead"
- )
- if attn_mask.dim() == 3:
- expected_shape = (num_heads, query.shape[0], key.shape[0])
- assert attn_mask.shape == expected_shape, (
- f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}"
- )
- else:
- raise AssertionError(
- f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor"
- )
- return is_batched
def _canonical_mask(
    mask: Optional[Tensor],
    mask_name: str,
    other_type: Optional[DType],
    other_name: str,
    target_type: DType,
    check_other: bool = True,
) -> Optional[Tensor]:
    """Validate an optional mask and bring it to additive floating-point form.

    Args:
        mask: the mask to canonicalize, or ``None``.
        mask_name: name of ``mask``, used in the error and warning messages.
        other_type: dtype of a companion mask, or ``None`` when there is none.
        other_name: name of the companion mask, used in the warning message.
        target_type: dtype of the tensor created when ``mask`` is boolean.
        check_other: when ``True``, warn if ``mask.dtype`` differs from
            ``other_type``.

    Returns:
        ``None`` when ``mask`` is ``None``; ``mask`` itself when it is already
        floating point; otherwise a new ``target_type`` tensor holding ``-inf``
        where ``mask`` is ``True`` and ``0`` elsewhere.

    Raises:
        AssertionError: if ``mask`` is neither boolean nor floating point.
    """
    if mask is None:
        return None

    is_float_mask = torch.is_floating_point(mask)
    is_bool_mask = mask.dtype == torch.bool
    if not (is_bool_mask or is_float_mask):
        raise AssertionError(
            f"only bool and floating types of {mask_name} are supported"
        )

    if check_other and other_type is not None and mask.dtype != other_type:
        warnings.warn(
            f"Support for mismatched {mask_name} and {other_name} "
            "is deprecated. Use same type for both instead."
        )

    if is_float_mask:
        # Floating masks are already additive; hand them back untouched.
        return mask

    # Boolean mask: True marks a masked-out position, which becomes -inf.
    additive_mask = torch.zeros_like(mask, dtype=target_type)
    additive_mask.masked_fill_(mask, float("-inf"))
    return additive_mask
def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
    """Return the dtype of ``input``, or ``None`` when ``input`` is ``None``.

    Raises:
        RuntimeError: if ``input`` is neither ``None`` nor a ``torch.Tensor``.
    """
    if input is None:
        return None
    if not isinstance(input, torch.Tensor):
        raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
    return input.dtype
- def _check_key_padding_mask(
- key_padding_mask: torch.Tensor, src_len: int, bsz: int
- ) -> None:
- torch._check_with(
- AssertionError,
- key_padding_mask.shape[0] == bsz,
- lambda: f"Expected key_padded_mask.shape[0] to be {bsz}, but got {key_padding_mask.shape[0]}",
- )
- torch._check_with(
- AssertionError,
- key_padding_mask.shape[1] == src_len,
- lambda: f"Expected key_padded_mask.shape[1] to be {src_len}, but got {key_padding_mask.shape[1]}",
- )
def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
) -> tuple[Tensor, Optional[Tensor]]:
    r"""Forward method for MultiHeadAttention.

    See :class:`torch.nn.MultiheadAttention` for details.

    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model. It must equal the last
            dimension of ``query``.
        num_heads: parallel attention heads. The embedding dimension must be
            divisible by ``num_heads``.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
            Either both or neither must be provided.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if ``True``; when ``False`` the dropout
            probability is forced to ``0.0``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
            Default: `True`
            Note: `need_weights` defaults to `True`, but should be set to `False`
            for best performance when attention weights are not needed.
            *Setting need_weights to `True`
            leads to a significant performance degradation.*
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        is_causal: If specified, hints that ``attn_mask`` is a causal mask.
            ``attn_mask`` must still be provided. When ``key_padding_mask`` is
            ``None`` and ``need_weights`` is ``False``, ``attn_mask`` is dropped and
            the hint is forwarded to scaled dot product attention instead; in every
            other case ``attn_mask`` itself is used.
            Default: ``False``.
            .. warning::
                is_causal provides a hint that the attn_mask is the
                causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.
        use_separate_proj_weight: the function accepts the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
            Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
            when ``need_weights=True``. Default: True

    Shape:
        Inputs:
        - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a FloatTensor is provided, it will be directly added to the value.
          If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

        Outputs:
        - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: Only returned when ``need_weights=True`` (``None`` otherwise).
          If ``average_attn_weights=True``, returns
          attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
          head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.

    Raises:
        RuntimeError: if ``is_causal`` is set without an ``attn_mask``, or if
            ``attn_mask`` is not 2-D/3-D or does not have the expected shape.
        AssertionError: if the embedding dimension does not match
            ``embed_dim_to_check`` or is not divisible by ``num_heads``, if ``key``
            and ``value`` shapes are inconsistent, if a required projection weight
            is ``None``, if only one of ``bias_k``/``bias_v`` is given, if biases are
            combined with static keys/values, or if ``static_k``/``static_v`` have
            unexpected sizes.
    """
    # Defer to __torch_function__ overrides when any tensor operand defines one.
    tens_ops = (
        query,
        key,
        value,
        in_proj_weight,
        in_proj_bias,
        bias_k,
        bias_v,
        out_proj_weight,
        out_proj_bias,
    )
    if has_torch_function(tens_ops):
        return handle_torch_function(
            multi_head_attention_forward,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            is_causal=is_causal,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
            average_attn_weights=average_attn_weights,
        )

    is_batched = _mha_shape_check(
        query, key, value, key_padding_mask, attn_mask, num_heads
    )

    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
    # is batched, run the computation and before returning squeeze the
    # batch dimension so that the output doesn't carry this temporary batch dimension.
    if not is_batched:
        # unsqueeze if the input is unbatched
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    # Boolean key padding masks become additive float masks in query's dtype;
    # a dtype mismatch with attn_mask only triggers a deprecation warning.
    key_padding_mask = _canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_type=_none_or_dtype(attn_mask),
        other_name="attn_mask",
        target_type=query.dtype,
    )

    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )

    if is_causal and key_padding_mask is None and not need_weights:
        # With no key padding mask and no weights requested, the explicit
        # attn_mask is dropped and the is_causal hint is handed to SDPA instead.
        # Whenever there is a kpm or weights are needed, attn_mask is required.
        attn_mask = None
    else:
        attn_mask = _canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        if key_padding_mask is not None:
            # We have the attn_mask, and use that to merge kpm into it.
            # Turn off use of is_causal hint, as the merged mask is no
            # longer causal.
            is_causal = False

    assert embed_dim == embed_dim_to_check, (
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    )
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, (
        f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    )
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert key.shape[:2] == value.shape[:2], (
            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
        )
    else:
        assert key.shape == value.shape, (
            f"key shape {key.shape} does not match value shape {value.shape}"
        )

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        assert in_proj_weight is not None, (
            "use_separate_proj_weight is False but in_proj_weight is None"
        )
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert q_proj_weight is not None, (
            "use_separate_proj_weight is True but q_proj_weight is None"
        )
        assert k_proj_weight is not None, (
            "use_separate_proj_weight is True but k_proj_weight is None"
        )
        assert v_proj_weight is not None, (
            "use_separate_proj_weight is True but v_proj_weight is None"
        )
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            # The packed bias is split into three equal chunks: q, k, v.
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(
            query,
            key,
            value,
            q_proj_weight,
            k_proj_weight,
            v_proj_weight,
            b_q,
            b_k,
            b_v,
        )

    # prep attention mask

    if attn_mask is not None:
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            # Leading singleton dim so the 2D mask broadcasts over batch * heads.
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.dim()} is not supported"
            )

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        # One extra key/value position is appended along dim 0 for every batch entry.
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        # Pad the masks' last dimension by one to cover the appended position.
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        # Providing only one of bias_k / bias_v is rejected.
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make them batch first
    #
    # After view + transpose: (bsz * num_heads, seq_len, head_dim).
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_k.size(0) == bsz * num_heads, (
            f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        )
        assert static_k.size(2) == head_dim, (
            f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        )
        k = static_k
    if static_v is None:
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_v.size(0) == bsz * num_heads, (
            f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        )
        assert static_v.size(2) == head_dim, (
            f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        )
        v = static_v

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat(
            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
        )
        v = torch.cat(
            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
        )
        # Pad the masks' last dimension by one to cover the appended zero position.
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        # The Python-side shape check is skipped while scripting or tracing.
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            _check_key_padding_mask(key_padding_mask, src_len, bsz)

        # (bsz, src_len) -> (bsz * num_heads, 1, src_len): one copy per head.
        key_padding_mask = (
            key_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, num_heads, -1, -1)
            .reshape(bsz * num_heads, 1, src_len)
        )
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            # Both masks are additive floats at this point, so they are summed.
            attn_mask = attn_mask + key_padding_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #

    if need_weights:
        # Explicit attention path: the weights must be materialized to return them.
        # E is the last dimension of q, i.e. the per-head dimension.
        _B, _Nt, E = q.shape
        q_scaled = q * math.sqrt(1.0 / float(E))

        assert not (is_causal and attn_mask is None), (
            "FIXME: is_causal not implemented for need_weights"
        )

        if attn_mask is not None:
            # baddbmm adds the additive mask to the q @ k^T batched product.
            attn_output_weights = torch.baddbmm(
                attn_mask, q_scaled, k.transpose(-2, -1)
            )
        else:
            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        attn_output_weights = softmax(attn_output_weights, dim=-1)
        if dropout_p > 0.0:
            attn_output_weights = dropout(attn_output_weights, p=dropout_p)

        attn_output = torch.bmm(attn_output_weights, v)

        # Back to sequence-first and flattened to (tgt_len * bsz, embed_dim)
        # for the output projection.
        attn_output = (
            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        )
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        # optionally average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.mean(dim=1)

        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)
        return attn_output, attn_output_weights
    else:
        # attn_mask can be either (L,S) or (N*num_heads, L, S)
        # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
        # in order to match the input for SDPA of (N, num_heads, L, S)
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

        q = q.view(bsz, num_heads, tgt_len, head_dim)
        k = k.view(bsz, num_heads, src_len, head_dim)
        v = v.view(bsz, num_heads, src_len, head_dim)

        attn_output = scaled_dot_product_attention(
            q, k, v, attn_mask, dropout_p, is_causal
        )
        # (bsz, num_heads, tgt_len, head_dim) -> (tgt_len, bsz, num_heads, head_dim),
        # then flattened to (bsz * tgt_len, embed_dim) for the output projection.
        attn_output = (
            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        )

        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
        # Attention weights are not computed on this path.
        return attn_output, None
|