framework.py 285 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627
76278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677
76778677967806781678267836784678567866787678867896790679167926793679467956796679767986799680068016802680368046805680668076808680968106811681268136814681568166817681868196820682168226823682468256826682768286829683068316832683368346835683668376838683968406841684268436844684568466847684868496850685168526853685468556856685768586859686068616862686368646865686668676868686968706871687268736874687568766877687868796880688168826883688468856886688768886889689068916892689368946895689668976898689969006901690269036904690569066907690869096910691169126913691469156916691769186919692069216922692369246925692669276928692969306931693269336934693569366937693869396940694169426943694469456946694769486949695069516952695369546955695669576958695969606961696269636964696569666967696869696970697169726973697469756976697769786979698069816982698369846985698669876988698969906991699269936994699569966997699869997000700170027003700470057006700770087009701070117012701370147015701670177018701970207021702270237024702570267027702870297030703170327033703470357036703770387039704070417042704370447045704670477048704970507051705270537054705570567057705870597060706170627063706470657066706770687069707070717072707370747075707670777078707970807081708270837084708570867087708870897090709170927093709470957096709770987099710071017102710371047105710671077108710971107111711271137114711571167117711871197120712171227123712471257126712771287129713071317132713371347135713671377138713971407141714271437144714571467147714871497150715171527153715471557156715771587159716071617162716371647165716671677168716971707171717271737174717571767177717871797180718171827183718471857186718771887189719071917192719371947195719671977198719972007201720272037204720572067207720872097210721172127213721472157216721772187219722072217222722372247225722672277228722972307231723272337234723572367237723872397240724172427243724472457246724772487249725072517252725372547255725672577258725972607261726272637264726572667267726872697270727172727273727472757276727
77278727972807281728272837284728572867287728872897290729172927293729472957296729772987299730073017302730373047305730673077308730973107311731273137314731573167317731873197320732173227323732473257326732773287329733073317332733373347335733673377338733973407341734273437344734573467347734873497350735173527353735473557356735773587359736073617362736373647365736673677368736973707371737273737374737573767377737873797380738173827383738473857386738773887389739073917392739373947395739673977398739974007401740274037404740574067407740874097410741174127413741474157416741774187419742074217422742374247425742674277428742974307431743274337434743574367437743874397440744174427443744474457446744774487449745074517452745374547455745674577458745974607461746274637464746574667467746874697470747174727473747474757476747774787479748074817482748374847485748674877488748974907491749274937494749574967497749874997500750175027503750475057506750775087509751075117512751375147515751675177518751975207521752275237524752575267527752875297530753175327533753475357536753775387539754075417542754375447545754675477548754975507551755275537554755575567557755875597560756175627563756475657566756775687569757075717572757375747575757675777578757975807581758275837584758575867587758875897590759175927593759475957596759775987599760076017602760376047605760676077608760976107611761276137614761576167617761876197620762176227623762476257626762776287629763076317632763376347635763676377638763976407641764276437644764576467647764876497650765176527653765476557656765776587659766076617662766376647665766676677668766976707671767276737674767576767677767876797680768176827683768476857686768776887689769076917692769376947695769676977698769977007701770277037704770577067707770877097710771177127713771477157716771777187719772077217722772377247725772677277728772977307731773277337734773577367737773877397740774177427743774477457746774777487749775077517752775377547755775677577758775977607761776277637764776577667767776877697770777177727773777477757776777
77778777977807781778277837784778577867787778877897790779177927793779477957796779777987799780078017802780378047805780678077808780978107811781278137814781578167817781878197820782178227823782478257826782778287829783078317832783378347835783678377838783978407841784278437844784578467847784878497850785178527853785478557856785778587859786078617862786378647865786678677868786978707871787278737874787578767877787878797880788178827883788478857886788778887889789078917892789378947895789678977898789979007901790279037904790579067907790879097910791179127913791479157916791779187919792079217922792379247925792679277928792979307931793279337934793579367937793879397940794179427943794479457946794779487949795079517952795379547955795679577958795979607961796279637964796579667967796879697970797179727973797479757976797779787979798079817982798379847985798679877988798979907991799279937994799579967997799879998000800180028003800480058006800780088009801080118012801380148015801680178018801980208021802280238024802580268027802880298030803180328033803480358036803780388039804080418042804380448045804680478048804980508051805280538054805580568057805880598060806180628063806480658066806780688069807080718072807380748075807680778078807980808081808280838084808580868087808880898090809180928093809480958096809780988099810081018102810381048105810681078108810981108111811281138114811581168117811881198120812181228123812481258126812781288129813081318132813381348135813681378138813981408141814281438144814581468147814881498150815181528153815481558156815781588159816081618162816381648165816681678168816981708171817281738174817581768177817881798180818181828183818481858186818781888189819081918192819381948195819681978198819982008201820282038204820582068207820882098210821182128213821482158216821782188219822082218222822382248225822682278228822982308231823282338234823582368237823882398240824182428243824482458246824782488249825082518252825382548255825682578258825982608261826282638264826582668267826882698270827182728273827482758276827
782788279828082818282828382848285828682878288828982908291829282938294829582968297829882998300830183028303830483058306830783088309831083118312831383148315831683178318831983208321832283238324832583268327832883298330
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import annotations
  15. import collections
  16. import copy
  17. import functools
  18. import multiprocessing
  19. import os
  20. import re
  21. import subprocess
  22. import sys
  23. import textwrap
  24. import threading
  25. import traceback
  26. import warnings
  27. from collections.abc import Iterable
  28. from types import FunctionType, MethodType
  29. from typing import TYPE_CHECKING
  30. import numpy as np
  31. import paddle
  32. import paddle.version as paddle_version
  33. from .. import pir
  34. from . import core, unique_name
  35. from .libpaddle import DataType
  36. from .proto import (
  37. data_feed_pb2, # noqa: F401
  38. framework_pb2,
  39. )
  40. from .variable_index import _getitem_static, _setitem_static
  41. from .wrapped_decorator import signature_safe_contextmanager, wrap_decorator
  42. if TYPE_CHECKING:
  43. from paddle.static.amp.fp16_utils import AmpOptions
  44. __all__ = []
  45. EMPTY_VAR_NAME = core.kEmptyVarName()
  46. TEMP_VAR_NAME = core.kTempVarName()
  47. GRAD_VAR_SUFFIX = core.kGradVarSuffix()
  48. ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
  49. CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
  50. _global_flags_ = core.globals()
  51. SUPPORT_PROMOTION_OPS_AND_INPUTNAME = {
  52. "elementwise_add": ['X', 'Y'],
  53. "elementwise_add_grad": ['X', 'Y'],
  54. "elementwise_sub": ['X', 'Y'],
  55. "elementwise_sub_grad": ['X', 'Y'],
  56. "elementwise_mul": ['X', 'Y'],
  57. "elementwise_mul_grad": ['X', 'Y'],
  58. "elementwise_div": ['X', 'Y'],
  59. "elementwise_div_grad": ['X', 'Y'],
  60. "elementwise_floordiv": ['X', 'Y'],
  61. "elementwise_floordiv_grad": ['X', 'Y'],
  62. "elementwise_pow": ['X', 'Y'],
  63. "elementwise_pow_grad": ['X', 'Y'],
  64. "where": ['X', 'Y'],
  65. "where_grad": ['X', 'Y'],
  66. "equal": ['X', 'Y'],
  67. "not_equal": ['X', 'Y'],
  68. "less_than": ['X', 'Y'],
  69. "less_equal": ['X', 'Y'],
  70. "greater_than": ['X', 'Y'],
  71. "greater_equal": ['X', 'Y'],
  72. "logical_and": ['X', 'Y'],
  73. "logical_or": ['X', 'Y'],
  74. "logical_xor": ['X', 'Y'],
  75. "elementwise_fmax": ['X', 'Y'],
  76. "elementwise_fmax_grad": ['X', 'Y'],
  77. "elementwise_fmin": ['X', 'Y'],
  78. "elementwise_fmin_grad": ['X', 'Y'],
  79. "elementwise_max": ['X', 'Y'],
  80. "elementwise_max_grad": ['X', 'Y'],
  81. "elementwise_min": ['X', 'Y'],
  82. "elementwise_min_grad": ['X', 'Y'],
  83. "elementwise_mod": ['X', 'Y'],
  84. "elementwise_mod_grad": ['X', 'Y'],
  85. "huber_loss": ['X', 'Y'],
  86. "huber_loss_grad": ['X', 'Y'],
  87. "nextafter": ['x', 'y'],
  88. "atan2": ['X1', 'X2'],
  89. "atan2_grad": ['X1', 'X2'],
  90. }
  91. def _global_flags():
  92. return _global_flags_
  93. def set_flags(flags):
  94. """
  95. This function sets the GFlags value in Paddle.
  96. For FLAGS please refer to :ref:`en_guides_flags_flags`
  97. Args:
  98. flags (dict): A dict contains flags and its value.
  99. Examples:
  100. .. code-block:: python
  101. >>> import paddle
  102. >>> paddle.set_flags({'FLAGS_eager_delete_tensor_gb': 1.0})
  103. """
  104. if not isinstance(flags, dict):
  105. raise TypeError("flags in set_flags should be a dict")
  106. for key, value in flags.items():
  107. if _global_flags().is_public(key):
  108. _global_flags()[key] = value
  109. else:
  110. raise ValueError(
  111. "Flag %s cannot set its value through this function." % (key)
  112. )
  113. def get_flags(flags):
  114. """
  115. This function gets the GFlags value in Paddle.
  116. For FLAGS please refer to :ref:`en_guides_flags_flags`
  117. Args:
  118. flags(list|tuple|str): A list/tuple of string or a string which is the flag's name.
  119. Returns:
  120. flag's value in Paddle.
  121. Examples:
  122. .. code-block:: python
  123. >>> import paddle
  124. >>> flags = ['FLAGS_eager_delete_tensor_gb', 'FLAGS_check_nan_inf']
  125. >>> res = paddle.get_flags(flags)
  126. >>> print(res)
  127. {'FLAGS_eager_delete_tensor_gb': 0.0, 'FLAGS_check_nan_inf': False}
  128. """
  129. flags_value = {}
  130. if isinstance(flags, (list, tuple)):
  131. for key in flags:
  132. if _global_flags().is_public(key):
  133. value = _global_flags()[key]
  134. temp = {key: value}
  135. flags_value.update(temp)
  136. else:
  137. raise ValueError(
  138. "Flag %s cannot get its value through this function."
  139. % (key)
  140. )
  141. elif isinstance(flags, str):
  142. if _global_flags().is_public(flags):
  143. value = _global_flags()[flags]
  144. temp = {flags: value}
  145. flags_value.update(temp)
  146. else:
  147. raise ValueError(
  148. "Flag %s cannot get its value through this function." % (flags)
  149. )
  150. else:
  151. raise TypeError("Flags in get_flags should be a list, tuple or string.")
  152. return flags_value
  153. # use thread local to create thread save global variables.
  154. class GlobalThreadLocal(threading.local):
  155. def __init__(self):
  156. """
  157. init the thread local data.
  158. TODO(xiongkun): how to access another thread local data ?
  159. """
  160. global _dygraph_tracer_
  161. self._in_to_static_mode_ = False
  162. self._functional_dygraph_context_manager = None
  163. self._dygraph_tracer_ = _dygraph_tracer_
  164. self._use_pir_api_ = get_flags("FLAGS_enable_pir_api")[
  165. "FLAGS_enable_pir_api"
  166. ]
  167. def __str__(self):
  168. strings = []
  169. strings.append("_in_to_static_mode_:" + str(self._in_to_static_mode_))
  170. strings.append(
  171. "_functional_dygraph_context_manager:"
  172. + str(self._functional_dygraph_context_manager)
  173. )
  174. strings.append("_dygraph_tracer_:" + str(self._dygraph_tracer_))
  175. return "\n".join(strings)
  176. def __setattr__(self, name, val):
  177. if name == "_dygraph_tracer_":
  178. global _dygraph_tracer_
  179. _dygraph_tracer_ = val
  180. core._switch_tracer(val)
  181. self.__dict__[name] = val
  182. _dygraph_tracer_ = None
  183. global_var = GlobalThreadLocal()
  184. _global_expected_place_ = None
  185. _current_device = None
  186. global_prog_seed = 0
  187. _current_pipeline_stage = None
  188. _current_cuda_graph_mode = None
  189. _stride_in_no_check_dy2st_diff_mode = False
  190. # special_op_attrs, extra_op_attrs are prepared for printing warnings
  191. # when turning on FLAGS_print_extra_attrs
  192. special_op_attrs = {
  193. "elementwise_add": [{"axis": -1}],
  194. "elementwise_sub": [{"axis": -1}],
  195. "elementwise_mul": [{"axis": -1}],
  196. "elementwise_div": [{"axis": -1}],
  197. "elementwise_max": [{"axis": -1}],
  198. "elementwise_min": [{"axis": -1}],
  199. "elementwise_pow": [{"axis": -1}],
  200. "elementwise_mod": [{"axis": -1}],
  201. "elementwise_floordiv": [{"axis": -1}],
  202. "less_than": [{"axis": -1}],
  203. "less_equal": [{"axis": -1}],
  204. "greater_than": [{"axis": -1}],
  205. "greater_equal": [{"axis": -1}],
  206. "equal": [{"axis": -1}],
  207. "not_equal": [{"axis": -1}],
  208. "amax": [{"reduce_all": False}],
  209. "amin": [{"reduce_all": False}],
  210. "any": [{"reduce_all": False}],
  211. "frobenius_norm": [{"reduce_all": False}],
  212. "logsumexp": [{"reduce_all": False}],
  213. "reduce_max": [{"reduce_all": False}],
  214. "reduce_min": [{"reduce_all": False}],
  215. "reduce_mean": [{"reduce_all": False}],
  216. "reduce_prod": [{"reduce_all": False}],
  217. "reduce_sum": [{"reduce_all": False}],
  218. }
  219. extra_op_attrs = {
  220. "gather": ["overwrite"],
  221. "graph_reindex": ["flag_buffer_hashtable"],
  222. "graph_sample_neighbors": ["flag_perm_buffer"],
  223. "relu6": ["threshold"],
  224. "swish": ["beta"],
  225. "hsigmoid_loss": ["remote_prefetch"],
  226. "max_pool2d_with_index": ["global_pooling"],
  227. "uniform": ["diag_num"],
  228. "unique": ["is_sorted"],
  229. }
  230. paddle_type_to_proto_type = {
  231. DataType.BOOL: core.VarDesc.VarType.BOOL,
  232. DataType.FLOAT16: core.VarDesc.VarType.FP16,
  233. DataType.UINT16: core.VarDesc.VarType.BF16,
  234. DataType.BFLOAT16: core.VarDesc.VarType.BF16,
  235. DataType.FLOAT32: core.VarDesc.VarType.FP32,
  236. DataType.FLOAT64: core.VarDesc.VarType.FP64,
  237. DataType.INT8: core.VarDesc.VarType.INT8,
  238. DataType.INT16: core.VarDesc.VarType.INT16,
  239. DataType.INT32: core.VarDesc.VarType.INT32,
  240. DataType.INT64: core.VarDesc.VarType.INT64,
  241. DataType.UINT8: core.VarDesc.VarType.UINT8,
  242. DataType.COMPLEX64: core.VarDesc.VarType.COMPLEX64,
  243. DataType.COMPLEX128: core.VarDesc.VarType.COMPLEX128,
  244. }
  245. def in_dygraph_mode():
  246. """
  247. .. note::
  248. Dynamic graph mode is turn ON by default since paddle 2.0.0
  249. This API checks whether paddle runs in dynamic graph mode.
  250. You can turn ON static graph mode by `enable_static <../dygraph/base/disable_dygraph_en.html>`_ ,
  251. and turn OFF static graph mode by `disable_static <../dygraph/base/enable_dygraph_en.html>`_ .
  252. Returns:
  253. bool: Whether paddle runs in dynamic graph mode.
  254. Examples:
  255. .. code-block:: python
  256. >>> import paddle
  257. >>> print(paddle.in_dynamic_mode()) # dynamic mode is turn ON by default since paddle 2.0.
  258. True
  259. >>> paddle.enable_static()
  260. >>> print(paddle.in_dynamic_mode()) # Now we are in static graph mode
  261. False
  262. >>> paddle.disable_static()
  263. >>> print(paddle.in_dynamic_mode()) # Now we are in dynamic mode
  264. True
  265. """
  266. return global_var._dygraph_tracer_ is not None
  267. def in_pir_mode():
  268. """
  269. This API checks whether paddle runs in static graph mode and use pir api.
  270. Returns:
  271. bool: Whether paddle runs in static graph mode and use pir api.
  272. Examples:
  273. .. code-block:: python
  274. >>> import paddle
  275. >>> print(paddle.framework.in_pir_mode())
  276. False
  277. >>> paddle.enable_static()
  278. >>> with paddle.pir_utils.IrGuard():
  279. ... print(paddle.framework.in_pir_mode())
  280. True
  281. """
  282. return global_var._use_pir_api_ and not in_dygraph_mode()
  283. def use_pir_api():
  284. return global_var._use_pir_api_
  285. def in_dynamic_or_pir_mode():
  286. """
  287. This API checks whether paddle runs in dynamic graph or pir mode.
  288. Returns:
  289. bool: Whether paddle runs in static graph mode and use pir api.
  290. Examples:
  291. .. code-block:: python
  292. >>> import paddle
  293. >>> print(paddle.framework.in_dynamic_or_pir_mode())
  294. True
  295. >>> paddle.enable_static()
  296. >>> print(paddle.framework.in_dynamic_or_pir_mode())
  297. False
  298. >>> with paddle.pir_utils.IrGuard():
  299. ... print(paddle.framework.in_dynamic_or_pir_mode())
  300. True
  301. """
  302. return global_var._dygraph_tracer_ is not None or global_var._use_pir_api_
  303. def in_pir_executor_mode():
  304. """
  305. This API checks whether paddle runs in pir executor mode.
  306. Returns:
  307. bool: Whether paddle runs in pir executor mode.
  308. """
  309. flag = str(os.environ.get("FLAGS_enable_pir_in_executor")).lower()
  310. return flag in ("true", "1")
  311. def in_cinn_mode():
  312. """
  313. This API checks whether paddle runs in cinn mode.
  314. Returns:
  315. bool: Whether paddle runs in cinn mode.
  316. """
  317. flag = str(os.environ.get("FLAGS_use_cinn")).lower()
  318. return flag in ("true", "1")
  319. global_ipu_index = -1
  320. global_ipu_stage = -1
  321. ipu_index_attr_name = "ipu_index"
  322. ipu_stage_attr_name = "ipu_stage"
  323. @signature_safe_contextmanager
  324. def ipu_shard_guard(index=-1, stage=-1):
  325. """
  326. Used to shard the graph on IPUs. Set each Op run on which IPU in the sharding and which stage in the pipelining.
  327. Args:
  328. index(int, optional): Specify which ipu the Tensor is computed on, (such as '0, 1, 2, 3').
  329. The default value is -1, which means the Op only run on IPU 0.
  330. stage(int, optional): Specify the computation order of the sharded model(such as '0, 1, 2, 3').
  331. The sharded model will be computed from small to large. The default value is -1,
  332. which means no pipelining computation order and run Ops in terms of graph.
  333. Note:
  334. Only if the enable_manual_shard=True, the 'index' is able to be set not -1. Please refer
  335. to :ref:`api_paddle_static_IpuStrategy`.
  336. Only if the enable_pipelining=True, the 'stage' is able to be set not -1. Please refer
  337. to :ref:`api_paddle_static_IpuStrategy`.
  338. A index is allowed to match none stage or a stage. A stage is only allowed to match a new or
  339. duplicated index.
  340. Examples:
  341. .. code-block:: python
  342. >>> # doctest: +REQUIRES(env:IPU)
  343. >>> import paddle
  344. >>> paddle.device.set_device('ipu')
  345. >>> paddle.enable_static()
  346. >>> a = paddle.static.data(name='data', shape=[None, 1], dtype='int32')
  347. >>> with paddle.static.ipu_shard_guard(index=0, stage=0):
  348. ... b = a + 1
  349. >>> with paddle.static.ipu_shard_guard(index=1, stage=1):
  350. ... c = b + 1
  351. >>> with paddle.static.ipu_shard_guard(index=0, stage=2):
  352. ... d = c + 1
  353. """
  354. if not core.is_compiled_with_ipu():
  355. raise ValueError(
  356. "Can not use this function since PaddlePaddle is not compiled with IPU"
  357. )
  358. global global_ipu_index
  359. global global_ipu_stage
  360. prev_ipu_index = global_ipu_index
  361. prev_ipu_stage = global_ipu_stage
  362. global_ipu_index = index
  363. global_ipu_stage = stage
  364. try:
  365. yield
  366. finally:
  367. global_ipu_index = prev_ipu_index
  368. global_ipu_stage = prev_ipu_stage
  369. def set_ipu_shard(call_func, index=-1, stage=-1):
  370. """
  371. Shard the ipu with the given call function. Set every ops in call function to the given ipu sharding.
  372. Note:
  373. Only when enable_manual_shard=True to set the index to a value other than -1. please refer to :ref:`api_paddle_static_IpuStrategy` .
  374. Only when enable_pipelining=True to set stage to a value other than -1. please refer to :ref:`api_paddle_static_IpuStrategy` .
  375. An index supports a corresponding None stage or a stage, and a stage only supports a new index or a duplicate index.
  376. Args:
  377. call_func(Layer|function): Specify the call function to be wrapped.
  378. index(int, optional): Specify which ipu the Tensor is computed on, (such as ‘0, 1, 2, 3’).
  379. The default value is -1, which means the Op only run on IPU 0.
  380. stage(int, optional): Specify the computation order of the sharded model(such as ‘0, 1, 2, 3’).
  381. The sharded model will be computed from small to large. The default value is -1,
  382. which means no pipelining computation order and run Ops in terms of graph.
  383. Returns:
  384. The wrapped call function.
  385. Examples:
  386. .. code-block:: python
  387. >>> # doctest: +REQUIRES(env:IPU)
  388. >>> import paddle
  389. >>> paddle.device.set_device('ipu')
  390. >>> paddle.enable_static()
  391. >>> a = paddle.static.data(name='data', shape=[None, 1], dtype='float32')
  392. >>> relu = paddle.nn.ReLU()
  393. >>> relu = paddle.static.set_ipu_shard(relu, index=1, stage=1)
  394. >>> relu(a)
  395. """
  396. def decorate(func):
  397. def wrapper(*args, **kwargs):
  398. with ipu_shard_guard(index=index, stage=stage):
  399. return func(*args, **kwargs)
  400. return wrapper
  401. from paddle.nn import Layer
  402. if not isinstance(call_func, Layer):
  403. if callable(call_func):
  404. return decorate(call_func)
  405. else:
  406. raise TypeError(
  407. "Unsupported type. Only accept paddle.nn.Layer or function."
  408. )
  409. # patch paddle.nn.Layer
  410. class BlockFn(type(call_func)):
  411. def __call__(self, *args, **kwargs):
  412. with ipu_shard_guard(index=index, stage=stage):
  413. return super().__call__(*args, **kwargs)
  414. BlockFn.__name__ = type(call_func).__name__
  415. call_func.__class__ = BlockFn
  416. return call_func
  417. def require_version(min_version, max_version=None):
  418. """
  419. Check if the installed version of PaddlePaddle is in [min_version, max_version],
  420. if the installed version is lower than ``min_version`` or higher than ``max_version``,
  421. an exception will be thrown, NO returns if the installed version is satisfied.
  422. Args:
  423. min_version (str): the minimum version required (like '1.4.0').
  424. max_version (str, optional): the max version required (like '1.6.0'), default is None,
  425. meaning any version equal or higher than ``min_version`` is acceptable.
  426. Returns:
  427. None.
  428. Raises:
  429. TypeError: if the type of ``min_version`` is not str.
  430. TypeError: if the type of ``max_version`` is not str or type(None).
  431. ValueError: if the value of ``min_version`` is not in version format.
  432. ValueError: if the value of ``max_version`` is not in version format or None.
  433. Exception: if the installed version is lower than ``min_version`` or higher than ``max_version``.
  434. Examples:
  435. .. code-block:: python
  436. >>> import paddle
  437. >>> # any version >= 0.1.0 is acceptable.
  438. >>> paddle.utils.require_version('0.1.0')
  439. >>> # if 0.1.0 <= version <= 10.0.0, it is acceptable.
  440. >>> paddle.utils.require_version(min_version='0.1.0', max_version='10.0.0')
  441. """
  442. if not isinstance(min_version, str):
  443. raise TypeError(
  444. "The type of 'min_version' in require_version must be str, but received %s."
  445. % (type(min_version))
  446. )
  447. if not isinstance(max_version, (str, type(None))):
  448. raise TypeError(
  449. "The type of 'max_version' in require_version must be str or type(None), but received %s."
  450. % (type(max_version))
  451. )
  452. check_format = re.match(r"\d+(\.\d+){0,3}", min_version)
  453. if check_format is None or check_format.group() != min_version:
  454. raise ValueError(
  455. "The value of 'min_version' in require_version must be in format '\\d+(\\.\\d+){0,3}', "
  456. "like '1.5.2.0', but received %s" % min_version
  457. )
  458. if max_version is not None:
  459. check_format = re.match(r"\d+(\.\d+){0,3}", max_version)
  460. if check_format is None or check_format.group() != max_version:
  461. raise ValueError(
  462. "The value of 'max_version' in require_version must be in format '\\d+(\\.\\d+){0,3}', "
  463. "like '1.5.2.0', but received %s" % max_version
  464. )
  465. version_installed = [
  466. paddle_version.major,
  467. paddle_version.minor,
  468. paddle_version.patch,
  469. paddle_version.rc,
  470. ]
  471. zero_version = ["0", "0", "0", "0"]
  472. def version_cmp(ver_a, ver_b):
  473. for i in range(len(ver_a)):
  474. if int(ver_a[i]) > int(ver_b[i]):
  475. return 1
  476. elif int(ver_a[i]) < int(ver_b[i]):
  477. return -1
  478. return 0
  479. if version_cmp(version_installed, zero_version) == 0:
  480. if max_version is not None:
  481. warnings.warn(
  482. f"PaddlePaddle version in [{min_version}, {max_version}] required, but {paddle_version.full_version} installed. "
  483. "Maybe you are using a develop version, "
  484. "please make sure the version is good with your code."
  485. )
  486. else:
  487. warnings.warn(
  488. f"PaddlePaddle version {min_version} or higher is required, but {paddle_version.full_version} installed, "
  489. "Maybe you are using a develop version, "
  490. "please make sure the version is good with your code."
  491. )
  492. return
  493. min_version_split = min_version.split(".")
  494. min_version_to_check = (
  495. min_version_split + zero_version[len(min_version_split) :]
  496. )
  497. if max_version is not None:
  498. max_version_split = max_version.split(".")
  499. max_version_to_check = (
  500. max_version_split + zero_version[len(max_version_split) :]
  501. )
  502. if (
  503. version_cmp(version_installed, max_version_to_check) > 0
  504. or version_cmp(version_installed, min_version_to_check) < 0
  505. ):
  506. raise Exception(
  507. f"VersionError: PaddlePaddle version in [{min_version}, {max_version}] required, but {paddle_version.full_version} installed."
  508. )
  509. else:
  510. if version_cmp(version_installed, min_version_to_check) < 0:
  511. raise Exception(
  512. f"VersionError: PaddlePaddle version {min_version} or higher is required, but {paddle_version.full_version} installed, "
  513. f"please upgrade your PaddlePaddle to {min_version} or other higher version."
  514. )
  515. def _dygraph_not_support_(func):
  516. def __impl__(*args, **kwargs):
  517. assert not in_dygraph_mode(), (
  518. "We don't support %s in dynamic graph mode" % func.__name__
  519. )
  520. return func(*args, **kwargs)
  521. return __impl__
  522. def _dygraph_only_(func):
  523. def __impl__(*args, **kwargs):
  524. assert in_dygraph_mode(), (
  525. "We only support '%s()' in dynamic graph mode, please call 'paddle.disable_static()' to enter dynamic graph mode."
  526. % func.__name__
  527. )
  528. return func(*args, **kwargs)
  529. return __impl__
  530. def _non_static_only_(func):
  531. def __impl__(*args, **kwargs):
  532. from .dygraph.base import in_to_static_mode
  533. assert in_dygraph_mode() or in_to_static_mode(), (
  534. "We only support '%s()' in dynamic graph mode, please call 'paddle.disable_static()' to enter dynamic graph mode."
  535. % func.__name__
  536. )
  537. return func(*args, **kwargs)
  538. return __impl__
  539. def _static_only_(func):
  540. def __impl__(*args, **kwargs):
  541. assert not in_dygraph_mode(), (
  542. "In PaddlePaddle 2.x, we turn on dynamic graph mode by default, and '%s()' is only supported in static graph mode. So if you want to use this api, please call 'paddle.enable_static()' before this api to enter static graph mode."
  543. % func.__name__
  544. )
  545. return func(*args, **kwargs)
  546. return __impl__
  547. def _set_pipeline_stage(stage):
  548. global _current_pipeline_stage
  549. _current_pipeline_stage = stage
  550. # NOTE(zhiqiu): This decorator is used for the APIs of Variable which is only
  551. # used to make Variable and Tensor has same interfaces, like numpy. Since Tensor is not exposed in our
  552. # official documents, logically, we want to keep Tensor and logically consistent. While, actually,
  553. # in our implementation, there some APIs not supported, like numpy, because Variable contains the desc.
  554. # So, those APIs are listed under class Variable to generate docs only.
  555. # TODO(zhiqiu): We should make Tensor consistent with Variable in future, for example, by inheriting
  556. # same base class.
  557. def _fake_interface_only_(func):
  558. def __impl__(*args, **kwargs):
  559. raise AssertionError(
  560. f"'{func.__name__}' only can be called by `paddle.Tensor` in dynamic graph mode. Suggestions:\n"
  561. " 1. If you are in static graph mode, you can switch to dynamic graph mode by turning off `paddle.enable_static()` or calling `paddle.disable_static()`.\n"
  562. " 2. If you are using `@paddle.jit.to_static`, you can call `paddle.jit.enable_to_static(False)`. "
  563. f"If you have to translate dynamic graph to static graph, please use other API to replace '{func.__name__}'."
  564. )
  565. return __impl__
  566. # NOTE(chenweihang): There is argument name typo (stat_dict, correct name is state_dict)
  567. # in base api Layer.set_dict, Optimizer.load, in order to correct the argument without
  568. # introducing compatibility issues, add this decorator
  569. # NOTE(chenweihang): not using `wrap_decorator` here is because `wrap_decorator` will
  570. # move kwargs to args, which doesn't work in this decorate case
  571. def deprecate_stat_dict(func):
  572. @functools.wraps(func)
  573. def wrapper(*args, **kwargs):
  574. if "stat_dict" in kwargs:
  575. warnings.warn(
  576. "The argument `stat_dict` has deprecated, please change it to `state_dict`.",
  577. DeprecationWarning,
  578. )
  579. kwargs["state_dict"] = kwargs["stat_dict"]
  580. kwargs.pop("stat_dict")
  581. return func(*args, **kwargs)
  582. return wrapper
  583. dygraph_not_support = wrap_decorator(_dygraph_not_support_)
  584. dygraph_only = wrap_decorator(_dygraph_only_)
  585. static_only = wrap_decorator(_static_only_)
  586. fake_interface_only = wrap_decorator(_fake_interface_only_)
  587. non_static_only = wrap_decorator(_non_static_only_)
  588. def _dygraph_tracer():
  589. return global_var._dygraph_tracer_
  590. def _current_expected_place_():
  591. global _global_expected_place_
  592. if _global_expected_place_ is None:
  593. if core.is_compiled_with_cuda():
  594. try:
  595. device_count = core.get_cuda_device_count()
  596. except Exception as e:
  597. device_count = 0
  598. if device_count > 0:
  599. _global_expected_place_ = core.CUDAPlace(_cuda_ids()[0])
  600. else:
  601. warnings.warn(
  602. "You are using GPU version Paddle, but your CUDA device is not set properly. CPU device will be used by default."
  603. )
  604. _global_expected_place_ = core.CPUPlace()
  605. elif core.is_compiled_with_xpu():
  606. try:
  607. device_count = core.get_xpu_device_count()
  608. except Exception as e:
  609. device_count = 0
  610. if device_count > 0:
  611. _global_expected_place_ = core.XPUPlace(_xpu_ids()[0])
  612. else:
  613. warnings.warn(
  614. "You are using XPU version Paddle, but your XPU device is not set properly. CPU device will be used by default."
  615. )
  616. _global_expected_place_ = core.CPUPlace()
  617. elif len(core.get_all_custom_device_type()) > 0:
  618. dev_type = core.get_all_custom_device_type()[0]
  619. try:
  620. device_count = core.get_custom_device_count(dev_type)
  621. except Exception as e:
  622. device_count = 0
  623. if device_count > 0:
  624. _global_expected_place_ = core.CustomPlace(
  625. dev_type, _custom_device_ids(dev_type)[0]
  626. )
  627. else:
  628. warnings.warn(
  629. "You are using CUSTOM_DEVICE version Paddle, but your custom device is not set properly. CPU device will be used by default."
  630. )
  631. _global_expected_place_ = core.CPUPlace()
  632. else:
  633. _global_expected_place_ = core.CPUPlace()
  634. return _global_expected_place_
  635. def _current_expected_place():
  636. if in_pir_mode():
  637. return core.Place()
  638. return _current_expected_place_()
  639. def _set_dygraph_tracer_expected_place(place):
  640. if global_var._dygraph_tracer_ is not None:
  641. global_var._dygraph_tracer_._expected_place = place
  642. def _set_expected_place(place):
  643. global _global_expected_place_
  644. _global_expected_place_ = place
  645. _set_dygraph_tracer_expected_place(place)
  646. def _cpu_num():
  647. if "CPU_NUM" not in os.environ.keys():
  648. if multiprocessing.cpu_count() > 1:
  649. sys.stderr.write(
  650. "!!! The CPU_NUM is not specified, you should set CPU_NUM in the environment variable list.\n"
  651. "CPU_NUM indicates that how many CPUPlace are used in the current task.\n"
  652. "And if this parameter are set as N (equal to the number of physical CPU core) the program may be faster.\n\n"
  653. f"export CPU_NUM={multiprocessing.cpu_count()} # for example, set CPU_NUM as number of physical CPU core which is {multiprocessing.cpu_count()}.\n\n"
  654. "!!! The default number of CPU_NUM=1.\n"
  655. )
  656. os.environ["CPU_NUM"] = str(1)
  657. cpu_num = os.environ.get("CPU_NUM")
  658. return int(cpu_num)
  659. def _cuda_ids():
  660. gpus_env = os.getenv("FLAGS_selected_gpus")
  661. if gpus_env:
  662. device_ids = [int(s) for s in gpus_env.split(",")]
  663. else:
  664. device_ids = range(core.get_cuda_device_count())
  665. return device_ids
  666. def _xpu_ids():
  667. xpus_env = os.getenv("FLAGS_selected_xpus")
  668. if xpus_env:
  669. device_ids = [int(s) for s in xpus_env.split(",")]
  670. else:
  671. device_ids = range(core.get_xpu_device_count())
  672. return device_ids
  673. def _custom_device_ids(device_type):
  674. custom_devices_env = os.getenv("FLAGS_selected_" + device_type + "s")
  675. if custom_devices_env:
  676. device_ids = [int(s) for s in custom_devices_env.split(",")]
  677. else:
  678. device_ids = range(core.get_custom_device_count(device_type))
  679. return device_ids
  680. def is_compiled_with_xpu():
  681. """
  682. Whether this whl package can be used to run the model on XPU.
  683. Returns (bool): support xpu or not.
  684. Examples:
  685. .. code-block:: python
  686. >>> import paddle.base as base
  687. >>> support_xpu = base.is_compiled_with_xpu()
  688. """
  689. return core.is_compiled_with_xpu()
  690. def disable_signal_handler():
  691. """
  692. Reset signal handler registered by Paddle.
  693. Paddle installs signal handlers at C++ level to log debug information upon failing.
  694. However, conflicts can happen if another python module is making use of such signal.
  695. Such being the case, one may disable paddle signal handler via this interface.
  696. Known frameworks that require disabling signal handler includes:
  697. 1. TVM
  698. 2. ADLIK
  699. Make sure you called paddle.disable_signal_handler() before using above mentioned frameworks.
  700. Returns:
  701. None
  702. Examples:
  703. .. code-block:: python
  704. >>> import paddle
  705. >>> paddle.disable_signal_handler()
  706. """
  707. core.disable_signal_handler()
  708. def is_compiled_with_cinn():
  709. """
  710. Whether this whl package can be used to run the model on CINN.
  711. Returns:
  712. Bool: `True` if CINN is currently available, otherwise `False`.
  713. Examples:
  714. .. code-block:: python
  715. >>> import paddle
  716. >>> support_cinn = paddle.device.is_compiled_with_cinn()
  717. """
  718. return core.is_compiled_with_cinn()
  719. def is_compiled_with_cuda():
  720. """
  721. Whether this whl package can be used to run the model on GPU.
  722. Returns:
  723. Bool: `True` if CUDA is currently available, otherwise `False`.
  724. Examples:
  725. .. code-block:: python
  726. >>> import paddle
  727. >>> support_gpu = paddle.device.is_compiled_with_cuda()
  728. """
  729. return core.is_compiled_with_cuda()
  730. def is_compiled_with_distribute():
  731. """
  732. Whether this whl package can be used to run the model with distribute.
  733. Returns:
  734. Bool: `True` if distribute is currently available, otherwise `False`.
  735. Examples:
  736. .. code-block:: python
  737. >>> import paddle
  738. >>> support_distribute = paddle.device.is_compiled_with_distribute()
  739. """
  740. return core.is_compiled_with_distribute()
  741. def is_compiled_with_rocm():
  742. """
  743. Whether this whl package can be used to run the model on AMD or Hygon GPU(ROCm).
  744. Returns:
  745. Bool: `True` if ROCm is currently available, otherwise `False`.
  746. Examples:
  747. .. code-block:: python
  748. >>> import paddle
  749. >>> support_gpu = paddle.device.is_compiled_with_rocm()
  750. """
  751. return core.is_compiled_with_rocm()
  752. def cuda_places(device_ids=None):
  753. """
  754. Note:
  755. For multi-card tasks, please use `FLAGS_selected_gpus` environment variable to set the visible GPU device.
  756. The next version will fix the problem with `CUDA_VISIBLE_DEVICES` environment variable.
  757. This function creates a list of :code:`paddle.CUDAPlace` objects.
  758. If :code:`device_ids` is None, environment variable of
  759. :code:`FLAGS_selected_gpus` would be checked first. For example, if
  760. :code:`FLAGS_selected_gpus=0,1,2`, the returned list would
  761. be [paddle.CUDAPlace(0), paddle.CUDAPlace(1), paddle.CUDAPlace(2)].
  762. If :code:`FLAGS_selected_gpus` is not set, all visible
  763. gpu places would be returned according to the :code:`CUDA_VISIBLE_DEVICES` environment variable.
  764. If :code:`device_ids` is not None, it should be the device
  765. ids of GPUs. For example, if :code:`device_ids=[0,1,2]`,
  766. the returned list would be
  767. [paddle.CUDAPlace(0), paddle.CUDAPlace(1), paddle.CUDAPlace(2)].
  768. Parameters:
  769. device_ids (list|tuple, optional): A list/tuple of int of GPU device ids.
  770. Returns:
  771. list of paddle.CUDAPlace: Created GPU place list.
  772. Examples:
  773. .. code-block:: python
  774. >>> # doctest: +REQUIRES(env:GPU)
  775. >>> import paddle
  776. >>> import paddle.static as static
  777. >>> paddle.device.set_device('gpu')
  778. >>> paddle.enable_static()
  779. >>> cuda_places = static.cuda_places()
  780. """
  781. assert core.is_compiled_with_cuda(), "Not compiled with CUDA"
  782. if device_ids is None:
  783. device_ids = _cuda_ids()
  784. elif not isinstance(device_ids, (list, tuple)):
  785. device_ids = [device_ids]
  786. return [core.CUDAPlace(dev_id) for dev_id in device_ids]
  787. def xpu_places(device_ids=None):
  788. """
  789. **Note**:
  790. For multi-card tasks, please use `FLAGS_selected_xpus` environment variable to set the visible XPU device.
  791. This function creates a list of :code:`paddle.XPUPlace` objects.
  792. If :code:`device_ids` is None, environment variable of
  793. :code:`FLAGS_selected_xpus` would be checked first. For example, if
  794. :code:`FLAGS_selected_xpus=0,1,2`, the returned list would
  795. be [paddle.XPUPlace(0), paddle.XPUPlace(1), paddle.XPUPlace(2)].
  796. If :code:`FLAGS_selected_xpus` is not set, all visible
  797. xpu places would be returned.
  798. If :code:`device_ids` is not None, it should be the device
  799. ids of XPUs. For example, if :code:`device_ids=[0,1,2]`,
  800. the returned list would be
  801. [paddle.XPUPlace(0), paddle.XPUPlace(1), paddle.XPUPlace(2)].
  802. Parameters:
  803. device_ids (list or tuple of int, optional): list of XPU device ids.
  804. Returns:
  805. list of paddle.XPUPlace: Created XPU place list.
  806. Examples:
  807. .. code-block:: python
  808. >>> # doctest: +REQUIRES(env:XPU)
  809. >>> import paddle
  810. >>> import paddle.static as static
  811. >>> paddle.device.set_device('xpu')
  812. >>> paddle.enable_static()
  813. >>> xpu_places = static.xpu_places()
  814. """
  815. assert core.is_compiled_with_xpu(), "Not compiled with XPU"
  816. if device_ids is None:
  817. device_ids = _xpu_ids()
  818. elif not isinstance(device_ids, (list, tuple)):
  819. device_ids = [device_ids]
  820. return [core.XPUPlace(dev_id) for dev_id in device_ids]
  821. def cpu_places(device_count=None):
  822. """
  823. This function creates a list of :code:`paddle.CPUPlace` objects, and returns the created list.
  824. If :code:`device_count` is None, the device count would
  825. be determined by environment variable :code:`CPU_NUM`.
  826. If :code:`CPU_NUM` is not set, the default value is 1,
  827. i.e. CPU_NUM=1.
  828. :code:`CPU_NUM` indicates the number of devices used in the current task.
  829. The running of the program can be accelerated if :code:`CPU_NUM` is the same as the number of physical cores.
  830. Parameters:
  831. device_count (int, optional): device number. Default: None.
  832. Returns:
  833. list of paddle.CPUPlace: Created list of CPU places.
  834. Examples:
  835. .. code-block:: python
  836. >>> import paddle
  837. >>> import paddle.static as static
  838. >>> paddle.enable_static()
  839. >>> cpu_places = static.cpu_places()
  840. """
  841. if device_count is None:
  842. device_count = _cpu_num()
  843. return [core.CPUPlace()] * device_count
  844. def cuda_pinned_places(device_count=None):
  845. """
  846. This function creates a list of :code:`base.CUDAPinnedPlace` objects.
  847. If :code:`device_count` is None, the device count would
  848. be determined by environment variable :code:`CPU_NUM`.
  849. If :code:`CPU_NUM` is not set, the default value is 1,
  850. i.e. CPU_NUM=1.
  851. :code:`CPU_NUM` indicates the number of devices used in the current task.
  852. The running of the program can be accelerated if :code:`CPU_NUM` is the same as the number of physical cores.
  853. Parameters:
  854. device_count (int, optional): device number. Default: None.
  855. Returns:
  856. list of base.CUDAPinnedPlace: Created list of CUDA pinned places.
  857. Examples:
  858. .. code-block:: python
  859. >>> # doctest: +REQUIRES(env:GPU)
  860. >>> import paddle.base as base
  861. >>> cuda_pinned_places_cpu_num = base.cuda_pinned_places()
  862. >>> # or
  863. >>> cuda_pinned_places = base.cuda_pinned_places(1)
  864. """
  865. assert core.is_compiled_with_cuda(), "Not compiled with CUDA"
  866. if device_count is None:
  867. device_count = len(_cuda_ids())
  868. return [core.CUDAPinnedPlace()] * device_count
  869. class NameScope:
  870. def __init__(self, name="", parent=None):
  871. self._children = {}
  872. self._name = name
  873. self._parent = parent
  874. def child(self, prefix):
  875. if prefix not in self._children:
  876. new_child = NameScope(prefix, self)
  877. self._children[prefix] = [new_child]
  878. else:
  879. new_child = NameScope(
  880. prefix + "_%d" % len(self._children[prefix]), self
  881. )
  882. self._children[prefix].append(new_child)
  883. return new_child
  884. def parent(self):
  885. return self._parent
  886. def name(self):
  887. return self._name
  888. _name_scope = NameScope()
  889. @signature_safe_contextmanager
  890. def name_scope(prefix=None):
  891. """
  892. Generate hierarchical name prefix for the operators in Static Graph.
  893. Note:
  894. This should only used for debugging and visualization purpose.
  895. Don't use it for serious analysis such as graph/program transformations.
  896. Don't use it in dygraph, since it will cause memory leak.
  897. Args:
  898. prefix(str, optional): prefix. Default is none.
  899. Examples:
  900. .. code-block:: python
  901. >>> import paddle
  902. >>> paddle.enable_static()
  903. >>> with paddle.static.name_scope("s1"):
  904. ... a = paddle.static.data(name='data', shape=[None, 1], dtype='int32')
  905. ... b = a + paddle.to_tensor(1)
  906. ... with paddle.static.name_scope("s2"):
  907. ... c = b * paddle.to_tensor(1)
  908. ... with paddle.static.name_scope("s3"):
  909. ... d = c / paddle.to_tensor(1)
  910. >>> with paddle.static.name_scope("s1"):
  911. ... f = paddle.tensor.pow(d, paddle.to_tensor(2.0))
  912. >>> with paddle.static.name_scope("s4"):
  913. ... g = f - paddle.to_tensor(1)
  914. >>> # Op are created in the default main program.
  915. >>> for op in paddle.static.default_main_program().block(0).ops:
  916. ... # elementwise_add is created in /s1/
  917. ... if op.type == 'elementwise_add':
  918. ... assert op.desc.attr("op_namescope") == '/s1/'
  919. ... # elementwise_mul is created in '/s1/s2'
  920. ... elif op.type == 'elementwise_mul':
  921. ... assert op.desc.attr("op_namescope") == '/s1/s2/'
  922. ... # elementwise_div is created in '/s1/s3'
  923. ... elif op.type == 'elementwise_div':
  924. ... assert op.desc.attr("op_namescope") == '/s1/s3/'
  925. ... # elementwise_sum is created in '/s4'
  926. ... elif op.type == 'elementwise_sub':
  927. ... assert op.desc.attr("op_namescope") == '/s4/'
  928. ... # pow is created in /s1_1/
  929. ... elif op.type == 'pow':
  930. ... assert op.desc.attr("op_namescope") == '/s1_1/'
  931. """
  932. # TODO(panyx0718): Only [0-9a-z].
  933. # in dygraph we don't need namescope since it will cause mem leak
  934. if in_dygraph_mode():
  935. yield
  936. else:
  937. assert prefix, "namescope prefix can not be empty."
  938. global _name_scope
  939. _name_scope = _name_scope.child(prefix)
  940. try:
  941. yield
  942. finally:
  943. _name_scope = _name_scope.parent()
  944. class NameStruct:
  945. def __init__(self, name="", parent=None):
  946. self._children = {}
  947. self._name = name
  948. self._parent = parent
  949. def child(self, prefix):
  950. if prefix not in self._children:
  951. new_child = NameStruct(prefix, self)
  952. self._children[prefix] = [new_child]
  953. else:
  954. new_child = NameStruct(
  955. prefix + "_%d" % len(self._children[prefix]), self
  956. )
  957. self._children[prefix].append(new_child)
  958. return new_child
  959. def parent(self):
  960. return self._parent
  961. def name(self):
  962. return self._name
  963. _name_struct = NameStruct()
  964. @signature_safe_contextmanager
  965. def name_struct(prefix=None):
  966. """
  967. Note: This should only used in Paddle/python/paddle/nn/layer/layers.py
  968. to record the call path for the operators in Static Graph of AutoParallel.
  969. Args:
  970. prefix(str, optional): prefix. Default is none.
  971. """
  972. # TODO(panyx0718): Only [0-9a-z].
  973. # in dygraph we don't need namescope since it will cause mem leak
  974. if in_dygraph_mode():
  975. yield
  976. else:
  977. assert prefix, "namescope prefix can not be empty."
  978. global _name_struct
  979. _name_struct = _name_struct.child(prefix)
  980. try:
  981. yield
  982. finally:
  983. _name_struct = _name_struct.parent()
  984. def _full_name_struct():
  985. global _name_struct
  986. struct = _name_struct
  987. name = ""
  988. while struct:
  989. name = struct.name() + "/" + name
  990. struct = struct.parent()
  991. return name
  992. def _full_name_scope():
  993. global _name_scope
  994. scope = _name_scope
  995. name = ""
  996. while scope:
  997. name = scope.name() + "/" + name
  998. scope = scope.parent()
  999. return name
  1000. def generate_control_dev_var_name():
  1001. import random
  1002. return CONTROL_DEP_VAR_PREFIX + "@" + str(random.random())
def grad_var_name(var_name):
    """
    Build the gradient variable name by appending ``GRAD_VAR_SUFFIX``.

    Args:
        var_name (str): name of the forward variable.

    Returns:
        str: gradient name for a certain var name
    """
    return var_name + GRAD_VAR_SUFFIX
  1009. def convert_np_dtype_to_proto_type(np_dtype: np.dtype | str):
  1010. """
  1011. Convert the data type in numpy to the data type in Paddle.
  1012. Args:
  1013. np_dtype (np.dtype|str): The data type in numpy or valid data type
  1014. string.
  1015. Returns:
  1016. core.VarDesc.VarType : The data type in Paddle.
  1017. """
  1018. # Convert the data type string to numpy data type.
  1019. if isinstance(np_dtype, str) and np_dtype == "bfloat16":
  1020. dtype = np.uint16
  1021. else:
  1022. dtype = np.dtype(np_dtype)
  1023. if dtype == np.float32:
  1024. return core.VarDesc.VarType.FP32
  1025. elif dtype == np.float64:
  1026. return core.VarDesc.VarType.FP64
  1027. elif dtype == np.float16:
  1028. return core.VarDesc.VarType.FP16
  1029. elif dtype == np.int32:
  1030. return core.VarDesc.VarType.INT32
  1031. elif dtype == np.int16:
  1032. return core.VarDesc.VarType.INT16
  1033. elif dtype == np.int64:
  1034. return core.VarDesc.VarType.INT64
  1035. elif dtype == np.bool_:
  1036. return core.VarDesc.VarType.BOOL
  1037. elif dtype == np.uint16:
  1038. # since there is still no support for bfloat16 in NumPy,
  1039. # uint16 is used for casting bfloat16
  1040. return core.VarDesc.VarType.BF16
  1041. elif dtype == np.uint8:
  1042. return core.VarDesc.VarType.UINT8
  1043. elif dtype == np.int8:
  1044. return core.VarDesc.VarType.INT8
  1045. elif dtype == np.complex64:
  1046. return core.VarDesc.VarType.COMPLEX64
  1047. elif dtype == np.complex128:
  1048. return core.VarDesc.VarType.COMPLEX128
  1049. else:
  1050. raise ValueError("Not supported numpy dtype %s" % dtype)
  1051. def convert_np_dtype_to_dtype_(np_dtype):
  1052. """
  1053. Convert the data type in numpy to the data type in Paddle.
  1054. Args:
  1055. np_dtype (np.dtype|str): The data type in numpy or valid data type
  1056. string.
  1057. Returns:
  1058. core.VarDesc.VarType / core.DataType : The data type in Paddle.
  1059. """
  1060. if use_pir_api():
  1061. return pir.core.convert_np_dtype_to_dtype_(np_dtype)
  1062. return convert_np_dtype_to_proto_type(np_dtype)
  1063. def convert_to_proto_type(dtype):
  1064. """
  1065. Convert the data type in numpy to the data type in Paddle.
  1066. Args:
  1067. dtype (np.dtype|str|core.DataType|core.VarDesc.VarType): The data type in numpy, valid data type
  1068. string or paddle dtype.
  1069. Returns:
  1070. core.VarDesc.VarType : The data type in Paddle.
  1071. """
  1072. if isinstance(dtype, core.VarDesc.VarType):
  1073. return dtype
  1074. elif isinstance(dtype, core.DataType):
  1075. return paddle_type_to_proto_type[dtype]
  1076. else:
  1077. return convert_np_dtype_to_proto_type(dtype)
  1078. def dtype_is_floating(dtype):
  1079. """
  1080. Check the data type is floating or not.
  1081. Args:
  1082. dtype(np.dtype|core.VarDesc.VarType): data type.
  1083. Could be numpy format or Paddle format
  1084. Returns(bool): True if data type is a float value
  1085. """
  1086. if not isinstance(dtype, core.VarDesc.VarType):
  1087. dtype = convert_np_dtype_to_dtype_(dtype)
  1088. return dtype in [
  1089. core.VarDesc.VarType.FP16,
  1090. core.VarDesc.VarType.FP32,
  1091. core.VarDesc.VarType.FP64,
  1092. ]
  1093. def _debug_string_(proto, throw_on_error=True):
  1094. """
  1095. Get the debug string of a protobuf message. The message could be not
  1096. initialized.
  1097. Args:
  1098. proto(google.protobuf.message.Message): The protobuf message
  1099. throw_on_error(bool): True if raise an error when the protobuf message
  1100. is not initialized.
  1101. Returns(str): The debug string of the protobuf message
  1102. """
  1103. error_fields = []
  1104. if not proto.IsInitialized(error_fields) and throw_on_error:
  1105. raise ValueError(
  1106. f"{error_fields} are not initialized.\nThe message is {proto}:\n"
  1107. )
  1108. return proto.__str__()
  1109. def _create_tensor(
  1110. type=core.VarDesc.VarType.LOD_TENSOR,
  1111. name=None,
  1112. shape=None,
  1113. dtype=None,
  1114. persistable=None,
  1115. **kwargs,
  1116. ):
  1117. if dtype is not None:
  1118. dtype = convert_to_proto_type(dtype)
  1119. else:
  1120. dtype = core.VarDesc.VarType.FP32
  1121. eager_tensor = core.eager.Tensor(
  1122. dtype,
  1123. list(shape) if shape else [],
  1124. name,
  1125. type if type else core.VarDesc.VarType.LOD_TENSOR,
  1126. True if persistable else False,
  1127. )
  1128. eager_tensor.retain_grads()
  1129. return eager_tensor
  1130. def _all_is_type(vals, expected_type):
  1131. """
  1132. Return True if type of each element is expected_type.
  1133. NOTE: BuiltIn all() will always return True if vals is empty.
  1134. """
  1135. assert isinstance(vals, (list, tuple))
  1136. if not vals:
  1137. return False
  1138. return all(isinstance(v, expected_type) for v in vals)
  1139. def wrap_as_scalar(number):
  1140. """Wrap a number(either python scalar or numpy scalar) as core.Scalar if
  1141. it is not a scalar.
  1142. Args:
  1143. number (Number): number
  1144. Returns:
  1145. Scalar: A Scalar that contains the value.
  1146. """
  1147. if isinstance(number, core.Scalar):
  1148. return number
  1149. if isinstance(number, (bool, int, float, complex)):
  1150. return core.Scalar(number)
  1151. if isinstance(number, np.number):
  1152. # it is a numpy scalar
  1153. return core.Scalar(number.item())
  1154. else:
  1155. raise TypeError(f"Cannot wrap {number} as core.Scalar")
  1156. def wrap_as_scalars(array):
  1157. """This function is used to convert flat list, or numpy array(not
  1158. necessarily flat) to list of core.Scalar, which correspond to
  1159. std::vector<paddle::experimental::Scalar> in operator runtime.
  1160. Args:
  1161. array (List | np.ndarray): array of numbers
  1162. Returns:
  1163. List: list of core.Scalar, of which each element is a Scalar containing
  1164. the corresponding value.
  1165. """
  1166. if isinstance(array, np.ndarray):
  1167. array = array.ravel().tolist()
  1168. return [wrap_as_scalar(item) for item in array]
  1169. def extract_plain_list(array):
  1170. """extract value from a list of core.Scalar.
  1171. Args:
  1172. array (list): Scalars
  1173. Returns:
  1174. list: values extracted from the scalars.
  1175. """
  1176. return [item.value() for item in array]
  1177. def canonicalize_attrs(attrs, op_proto):
  1178. """This function is used to canonicalize attributes(as a string->any dict)
  1179. according to the type specification in the OpProto. This is especially
  1180. important for operators that has any attributes of type Scalar or Scalars.
  1181. Though various frontends of phi kernels & paddle operators can wrap variables
  1182. of concrete types into Scalars(a tagged union of several numeric types) or
  1183. vector of Scalars. Paddle operator requires strict type matching.
  1184. Args:
  1185. attrs (Dict[str, Any]): attribute dict intended to pass to an operator.
  1186. op_proto (OpProto): Proto (signature) of the operator.
  1187. Returns:
  1188. Dict[str, Any]: canonicalized attributes.
  1189. """
  1190. canonicalized_attrs = attrs.copy() # shallow copy is enough here
  1191. for attr in op_proto.attrs:
  1192. attr_name = attr.name
  1193. type_index = attr.type
  1194. if (attr_name not in attrs) or (attrs[attr_name] is None):
  1195. continue
  1196. attr_val = attrs[attr_name]
  1197. # VAR and VARS should be skipped
  1198. if isinstance(attr_val, Variable):
  1199. continue
  1200. if isinstance(attr_val, list) and _all_is_type(attr_val, Variable):
  1201. continue
  1202. # wrap
  1203. if type_index == core.AttrType.SCALAR:
  1204. canonicalized_attrs[attr_name] = core.Scalar(attr_val)
  1205. elif type_index == core.AttrType.SCALARS:
  1206. # it should be a list (or a numpy array)
  1207. if len(attr_val) > 0:
  1208. attr_val = np.array(attr_val).ravel().tolist()
  1209. attr_val = [core.Scalar(x) for x in attr_val]
  1210. canonicalized_attrs[attr_name] = attr_val
  1211. return canonicalized_attrs
  1212. class VariableMetaClass(type):
  1213. @classmethod
  1214. def __instancecheck__(cls, instance):
  1215. t = type(instance)
  1216. if in_dygraph_mode():
  1217. return issubclass(t, core.eager.Tensor)
  1218. else:
  1219. return issubclass(t, Variable)
  1220. class ParameterMetaClass(VariableMetaClass):
  1221. @classmethod
  1222. def __instancecheck__(cls, instance):
  1223. t = type(instance)
  1224. if in_dygraph_mode():
  1225. return issubclass(t, EagerParamBase)
  1226. else:
  1227. return issubclass(t, Parameter)
  1228. class Variable(metaclass=VariableMetaClass):
  1229. """
  1230. Notes:
  1231. The constructor of Variable should not be invoked directly.
  1232. In Static Graph Mode: Please use ** `Block.create_var` ** to create a Static variable which has no data until being feed.
  1233. In Dygraph Mode: Please use ** :ref:`api_paddle_to_tensor` ** to create a dygraph variable with real data.
  1234. In Fluid, every input and output of an OP is a variable. In most
  1235. cases, variables are used for holding different kinds of data or training
  1236. labels. A variable belongs to a :ref:`api_guide_Block_en` . All variable has its own name and
  1237. two variables in different :ref:`api_guide_Block_en` could have the same name.
  1238. There are many kinds of variables. Each kind of them has its own attributes
  1239. and usages. Please refer to the `framework.proto <https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/base/framework/framework.proto>`_ for details.
  1240. Most of a Variable's member variables can be set to be None. It mean
  1241. it is not available or will be specified later.
  1242. Examples:
  1243. In Static Graph Mode:
  1244. .. code-block:: python
  1245. :name: code-example-1
  1246. >>> import paddle.base as base
  1247. >>> cur_program = base.Program()
  1248. >>> cur_block = cur_program.current_block()
  1249. >>> new_variable = cur_block.create_var(name="X",
  1250. ... shape=[-1, 23, 48],
  1251. ... dtype='float32')
  1252. In Dygraph Mode:
  1253. .. code-block:: python
  1254. :name: code-example-2
  1255. >>> import paddle.base as base
  1256. >>> import numpy as np
  1257. >>> import paddle
  1258. >>> with base.dygraph.guard():
  1259. ... new_variable = paddle.to_tensor(np.arange(10))
  1260. """
  1261. def __init__(
  1262. self,
  1263. block,
  1264. type=core.VarDesc.VarType.LOD_TENSOR,
  1265. name=None,
  1266. shape=None,
  1267. dtype=None,
  1268. lod_level=None,
  1269. capacity=None,
  1270. persistable=None,
  1271. error_clip=None,
  1272. stop_gradient=False,
  1273. is_data=False,
  1274. need_check_feed=False,
  1275. belong_to_optimizer=False,
  1276. **kwargs,
  1277. ):
  1278. self.block = block
  1279. if name is None:
  1280. name = self.block.program._name_generator("_generated_var")
  1281. while self.block._find_var_recursive(name) is not None:
  1282. name = self.block.program._name_generator("_generated_var")
  1283. if dtype is not None:
  1284. dtype = convert_to_proto_type(dtype)
  1285. if dtype == core.VarDesc.VarType.STRINGS:
  1286. type = core.VarDesc.VarType.STRINGS
  1287. lod_level = None
  1288. if type == core.VarDesc.VarType.SPARSE_COO:
  1289. lod_level = None
  1290. self.belong_to_optimizer = belong_to_optimizer
  1291. self.error_clip = error_clip
  1292. is_new_var = False
  1293. self.desc = self.block.desc.find_var(name.encode())
  1294. if self.desc is None:
  1295. self.desc = self.block.desc.var(name.encode())
  1296. is_new_var = True
  1297. if is_new_var:
  1298. self.desc.set_type(type)
  1299. elif self.desc.type() != type:
  1300. raise ValueError(
  1301. f"Variable '{self.name}' has been created before. The "
  1302. f"previous type is {self.desc.type()}, the new type is {type}. They"
  1303. " are not matched"
  1304. )
  1305. if shape is not None:
  1306. if is_new_var:
  1307. self.desc.set_shape(shape)
  1308. else:
  1309. old_shape = self.shape
  1310. shape = tuple(shape)
  1311. if shape != old_shape:
  1312. raise ValueError(
  1313. f"Variable '{self.name}' has been created before. The previous "
  1314. f"shape is {old_shape}, the new shape is {shape}. They are not "
  1315. "matched."
  1316. )
  1317. if dtype is not None:
  1318. if is_new_var:
  1319. self.desc.set_dtype(dtype)
  1320. else:
  1321. old_dtype = self.dtype
  1322. if dtype != old_dtype:
  1323. raise ValueError(
  1324. f"Variable '{self.name}' has been created before. "
  1325. f"The previous data type is {old_dtype}, the new "
  1326. f"data type is {dtype}. They are not "
  1327. "matched."
  1328. )
  1329. if lod_level is not None:
  1330. if is_new_var:
  1331. self.desc.set_lod_level(lod_level)
  1332. else:
  1333. if lod_level != self.lod_level:
  1334. raise ValueError(
  1335. f"Variable '{self.name}' has been created before. "
  1336. f"The previous lod_level is {self.lod_level}, the new "
  1337. f"lod_level is {lod_level}. They are not "
  1338. "matched"
  1339. )
  1340. if persistable is not None:
  1341. if is_new_var:
  1342. self.desc.set_persistable(persistable)
  1343. else:
  1344. if persistable != self.persistable:
  1345. raise ValueError(
  1346. f"Variable '{self.name}' has been created before."
  1347. f"The previous persistable is {self.persistable}, the new "
  1348. f"persistable is {persistable}. They are not matched"
  1349. )
  1350. if need_check_feed and is_new_var:
  1351. self.desc.set_need_check_feed(need_check_feed)
  1352. if capacity is not None:
  1353. if is_new_var:
  1354. self.desc.set_capacity(capacity)
  1355. else:
  1356. # TODO(abhinavarora) : Compare with set capacity once,
  1357. # get_capacity is implemented
  1358. pass
  1359. self.block.vars[name] = self
  1360. self.op = None
  1361. self.stop_gradient = stop_gradient
  1362. self.is_data = is_data
  1363. self.is_view_var = False
  1364. def detach(self):
  1365. """
  1366. Returns a new Variable, detached from the current graph.
  1367. It will share data with origin Variable and without tensor copy.
  1368. In addition, the detached Variable doesn't provide gradient propagation.
  1369. Returns:
  1370. ( :ref:`api_guide_Variable_en` | dtype is same as current Variable), The detached Variable.
  1371. Examples:
  1372. .. code-block:: python
  1373. >>> import paddle
  1374. >>> paddle.enable_static()
  1375. >>> # create a static Variable
  1376. >>> x = paddle.static.data(name='x', shape=[3, 2, 1])
  1377. >>> # create a detached Variable
  1378. >>> y = x.detach()
  1379. """
  1380. assert (
  1381. self.type == core.VarDesc.VarType.SELECTED_ROWS
  1382. or self.type == core.VarDesc.VarType.LOD_TENSOR
  1383. ), "only support a variable with SELECTED_ROWS or LOD_TENSOR to be detached"
  1384. with unique_name.guard(self.block.program._name_generator):
  1385. output = self.block.create_var(
  1386. name=unique_name.generate_with_ignorable_key(
  1387. "detach_" + self.name
  1388. ),
  1389. dtype=self.dtype,
  1390. type=self.type,
  1391. persistable=self.persistable,
  1392. stop_gradient=True,
  1393. )
  1394. self.block.append_op(
  1395. type="share_data",
  1396. inputs={"X": [self]},
  1397. outputs={"Out": [output]},
  1398. )
  1399. return output
  1400. @fake_interface_only
  1401. def numpy(self):
  1402. """
  1403. **Notes**:
  1404. **This API is ONLY available in Dygraph mode**
  1405. Returns a numpy array shows the value of current :ref:`api_guide_Variable_en`
  1406. Returns:
  1407. ndarray: The numpy value of current Variable.
  1408. Returns type:
  1409. ndarray: dtype is same as current Variable
  1410. Examples:
  1411. .. code-block:: python
  1412. >>> import paddle.base as base
  1413. >>> from paddle.nn import Linear
  1414. >>> import numpy as np
  1415. >>> data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
  1416. >>> with base.dygraph.guard():
  1417. ... linear = Linear(32, 64)
  1418. ... data = paddle.to_tensor(data)
  1419. ... x = linear(data)
  1420. ... print(x.numpy())
  1421. """
  1422. pass
  1423. @non_static_only
  1424. def backward(self, retain_graph=False):
  1425. """
  1426. **Notes**:
  1427. **This API is ONLY available in Dygraph mode**
  1428. Run backward of current Graph which starts from current Tensor.
  1429. Args:
  1430. retain_graph(bool, optional): If False, the graph used to compute grads will be freed. If you would
  1431. like to add more ops to the built graph after calling this method( :code:`backward` ), set the parameter
  1432. :code:`retain_graph` to True, then the grads will be retained. Thus, setting it to False is much more memory-efficient.
  1433. Defaults to False.
  1434. Returns:
  1435. NoneType: None
  1436. Examples:
  1437. .. code-block:: python
  1438. >>> import numpy as np
  1439. >>> import paddle
  1440. >>> paddle.disable_static()
  1441. >>> x = np.ones([2, 2], np.float32)
  1442. >>> inputs = []
  1443. >>> for _ in range(10):
  1444. ... tmp = paddle.to_tensor(x)
  1445. ... # if we don't set tmp's stop_gradient as False then, all path to loss will has no gradient since
  1446. ... # there is no one need gradient on it.
  1447. ... tmp.stop_gradient=False
  1448. ... inputs.append(tmp)
  1449. >>> ret = paddle.add_n(inputs)
  1450. >>> loss = paddle.sum(ret)
  1451. >>> loss.backward()
  1452. """
  1453. from .backward import append_backward
  1454. if retain_graph is True:
  1455. raise AssertionError(
  1456. "`retain_graph` == True is not supported in @to_static function."
  1457. "please set retain_graph = False."
  1458. )
  1459. param_grad_list = append_backward(self)
  1460. for param, param_grad in param_grad_list:
  1461. # set grad to simulate dygraph loss.backward() in static mode.
  1462. param.grad = param_grad
    @fake_interface_only
    def gradient(self):
        """
        **Notes**:
            **This API is ONLY available in Dygraph mode**

        Get the Gradient of Current Variable

        The method body here is only a placeholder; it is wrapped by
        ``fake_interface_only``.

        Returns:
            ndarray or tuple of ndarray: if Variable's type is LoDTensor, return numpy value of the gradient of current Variable, if Variable's type is SelectedRows, return tuple of ndarray, first element of tuple is numpy value of the gradient of current Variable, second element of tuple is numpy value of the rows of current Variable.

        Examples:
            .. code-block:: python

                >>> import paddle
                >>> import paddle.base as base
                >>> import numpy as np

                >>> # example1: return ndarray
                >>> x = np.ones([2, 2], np.float32)
                >>> with base.dygraph.guard():
                ...     inputs2 = []
                ...     for _ in range(10):
                ...         tmp = paddle.to_tensor(x)
                ...         tmp.stop_gradient=False
                ...         inputs2.append(tmp)
                ...     ret2 = paddle.add_n(inputs2)
                ...     loss2 = paddle.sum(ret2)
                ...     loss2.retain_grads()
                ...     loss2.backward()
                ...     print(loss2.gradient())

                >>> # example2: return tuple of ndarray
                >>> with base.dygraph.guard():
                ...     embedding = paddle.nn.Embedding(
                ...         20,
                ...         32,
                ...         weight_attr='emb.w',
                ...         sparse=True)
                ...     x_data = np.arange(12).reshape(4, 3).astype('int64')
                ...     x_data = x_data.reshape((-1, 3, 1))
                ...     x = paddle.to_tensor(x_data)
                ...     out = embedding(x)
                ...     out.backward()
                ...     print(embedding.weight.gradient())
        """
        pass
  1504. @fake_interface_only
  1505. def clear_gradient(self):
  1506. """
  1507. **Notes**:
  1508. **1. This API is ONLY available in Dygraph mode**
  1509. **2. Use it only Variable has gradient, normally we use this for Parameters since other temporal Variable will be deleted by Python's GC**
  1510. Clear (set to ``0`` ) the Gradient of Current Variable
  1511. Returns: None
  1512. Examples:
  1513. .. code-block:: python
  1514. >>> import paddle
  1515. >>> import paddle.base as base
  1516. >>> import numpy as np
  1517. >>> x = np.ones([2, 2], np.float32)
  1518. >>> inputs2 = []
  1519. >>> for _ in range(10):
  1520. >>> tmp = paddle.to_tensor(x)
  1521. >>> tmp.stop_gradient=False
  1522. >>> inputs2.append(tmp)
  1523. >>> ret2 = paddle.add_n(inputs2)
  1524. >>> loss2 = paddle.sum(ret2)
  1525. >>> loss2.retain_grads()
  1526. >>> loss2.backward()
  1527. >>> print(loss2.gradient())
  1528. >>> loss2.clear_gradient()
  1529. >>> print("After clear {}".format(loss2.gradient()))
  1530. 1.0
  1531. After clear 0.0
  1532. """
  1533. pass
  1534. def register_hook(self, hook):
  1535. import paddle
  1536. def backward_hook_wrapper(dy):
  1537. """call the backward hook in ."""
  1538. return hook(np.array(dy))
  1539. def forward_hook_wrapper(x):
  1540. """do nothing but return a new variable."""
  1541. return x
  1542. paddle.static.py_func(
  1543. func=forward_hook_wrapper,
  1544. x=self,
  1545. out=self,
  1546. backward_func=backward_hook_wrapper,
  1547. skip_vars_in_backward_input=[self],
  1548. )
  1549. def apply(self, func):
  1550. if not self.stop_gradient:
  1551. raise RuntimeError(
  1552. "Cannot apply function on a tensor that required gradient."
  1553. )
  1554. try:
  1555. return func(self)
  1556. except:
  1557. raise ValueError(f"The PyFunc {func.__name__} could not be applied")
    def __str__(self):
        """Return the readable form produced by ``_to_readable_code``."""
        return self._to_readable_code()
  1560. def _to_readable_code(self):
  1561. """
  1562. Get readable debug string of Variable.
  1563. .. note::
  1564. If you want to get the debug string in protobuf format,
  1565. please use :code:`to_string` method.
  1566. Returns:
  1567. string: The formatted Variable string.
  1568. Examples:
  1569. .. code-block:: python
  1570. >>> import paddle
  1571. >>> import paddle.static as static
  1572. >>> paddle.enable_static()
  1573. >>> cur_program = static.Program()
  1574. >>> cur_block = cur_program.current_block()
  1575. >>> new_variable = cur_block.create_var(name="X",
  1576. ... shape=[-1, 23, 48],
  1577. ... dtype='float32')
  1578. >>> print(new_variable._to_readable_code())
  1579. var X : LOD_TENSOR.shape(-1, 23, 48).dtype(float32).stop_gradient(False)
  1580. """
  1581. # VarType.LOD_TENSOR -> LOD_TENSOR
  1582. type_str = str(self.type).split(".")[1]
  1583. if (
  1584. self.type == core.VarDesc.VarType.SELECTED_ROWS
  1585. or self.type == core.VarDesc.VarType.LOD_TENSOR
  1586. ):
  1587. dtype_str = str(self.dtype).split(".")[1]
  1588. var_str = f"{self.name} : {type_str}.shape{self.shape}.dtype({dtype_str}).stop_gradient({self.stop_gradient})"
  1589. else:
  1590. var_str = f"{self.name} : {type_str})"
  1591. if self.is_parameter:
  1592. if self.trainable:
  1593. var_str = "trainable param " + var_str
  1594. else:
  1595. var_str = "param " + var_str
  1596. else:
  1597. var_str = "var " + var_str
  1598. if self.persistable:
  1599. var_str = "persist " + var_str
  1600. from paddle.distributed.auto_parallel.static.dist_context import (
  1601. get_default_distributed_context,
  1602. )
  1603. dist_context = get_default_distributed_context()
  1604. dist_tensor = dist_context.get_dist_tensor_for_program(self)
  1605. if dist_tensor is not None:
  1606. var_str += ", {name} = {value}".format(
  1607. name="dist_attr", value=dist_tensor
  1608. )
  1609. return var_str
  1610. def to_string(self, throw_on_error, with_details=False):
  1611. """
  1612. Get debug string.
  1613. Args:
  1614. throw_on_error (bool): True if raise an exception when self is not initialized.
  1615. with_details (bool): more details about variables and parameters (e.g. trainable, optimize_attr, ...) will be printed when with_details is True. Default value is False;
  1616. Returns:
  1617. str: The debug string.
  1618. Examples:
  1619. .. code-block:: python
  1620. >>> import paddle.base as base
  1621. >>> import paddle
  1622. >>> paddle.enable_static()
  1623. >>> cur_program = base.Program()
  1624. >>> cur_block = cur_program.current_block()
  1625. >>> new_variable = cur_block.create_var(name="X",
  1626. ... shape=[-1, 23, 48],
  1627. ... dtype='float32')
  1628. >>> print(new_variable.to_string(True))
  1629. >>> print("=============with detail===============")
  1630. >>> print(new_variable.to_string(True, True))
  1631. name: "X"
  1632. type {
  1633. type: LOD_TENSOR
  1634. lod_tensor {
  1635. tensor {
  1636. data_type: FP32
  1637. dims: -1
  1638. dims: 23
  1639. dims: 48
  1640. }
  1641. }
  1642. }
  1643. stop_gradient: false
  1644. error_clip: None
  1645. """
  1646. assert isinstance(throw_on_error, bool) and isinstance(
  1647. with_details, bool
  1648. )
  1649. protostr = self.desc.serialize_to_string()
  1650. proto = framework_pb2.VarDesc.FromString(bytes(protostr))
  1651. res_str = _debug_string_(proto, throw_on_error)
  1652. if with_details:
  1653. additional_attr = ("error_clip",)
  1654. for attr_name in additional_attr:
  1655. res_str += f"{attr_name}: {getattr(self, attr_name)}\n"
  1656. return res_str
# repr() of a Variable uses the same implementation as str().
__repr__ = __str__
  1658. def element_size(self):
  1659. """
  1660. Returns the size in bytes of an element in the Tensor.
  1661. Examples:
  1662. .. code-block:: python
  1663. >>> import paddle
  1664. >>> paddle.enable_static()
  1665. >>> x = paddle.static.data(name='x1', shape=[3, 2], dtype='bool')
  1666. >>> print(x.element_size())
  1667. 1
  1668. >>> x = paddle.static.data(name='x2', shape=[3, 2], dtype='int16')
  1669. >>> print(x.element_size())
  1670. 2
  1671. >>> x = paddle.static.data(name='x3', shape=[3, 2], dtype='float16')
  1672. >>> print(x.element_size())
  1673. 2
  1674. >>> x = paddle.static.data(name='x4', shape=[3, 2], dtype='float32')
  1675. >>> print(x.element_size())
  1676. 4
  1677. >>> x = paddle.static.data(name='x5', shape=[3, 2], dtype='float64')
  1678. >>> print(x.element_size())
  1679. 8
  1680. """
  1681. return self.desc.element_size()
  1682. @property
  1683. def stop_gradient(self):
  1684. """
  1685. Indicating if we stop gradient from current Variable
  1686. **Notes: This Property has default value as** ``True`` **in** Dygraph **mode, while Parameter's default value is False. However, in Static Graph Mode all Variable's default stop_gradient value is** ``False``
  1687. Examples:
  1688. .. code-block:: python
  1689. >>> import paddle
  1690. >>> import paddle.base as base
  1691. >>> import numpy as np
  1692. >>> with base.dygraph.guard():
  1693. ... value0 = np.arange(26).reshape(2, 13).astype("float32")
  1694. ... value1 = np.arange(6).reshape(2, 3).astype("float32")
  1695. ... value2 = np.arange(10).reshape(2, 5).astype("float32")
  1696. ... linear = paddle.nn.Linear(13, 5)
  1697. ... linear2 = paddle.nn.Linear(3, 3)
  1698. ... a = paddle.to_tensor(value0)
  1699. ... b = paddle.to_tensor(value1)
  1700. ... c = paddle.to_tensor(value2)
  1701. ... out1 = linear(a)
  1702. ... out2 = linear2(b)
  1703. ... out1.stop_gradient = True
  1704. ... out = paddle.concat(x=[out1, out2, c], axis=1)
  1705. ... out.backward()
  1706. ... assert linear.weight.gradient() is None
  1707. ... assert out1.gradient() is None
  1708. """
  1709. return self.desc.stop_gradient()
  1710. @stop_gradient.setter
  1711. def stop_gradient(self, s):
  1712. self.desc.set_stop_gradient(s)
  1713. @property
  1714. def persistable(self):
  1715. """
  1716. Indicating if we current Variable should be long-term alive
  1717. **Notes: This Property will be deprecated and this API is just to help user understand concept**
  1718. **1. All Variable's persistable is** ``False`` **except Parameters.**
  1719. **2. In** Dygraph **mode, this property should not be changed**
  1720. Examples:
  1721. .. code-block:: python
  1722. >>> import paddle.base as base
  1723. >>> cur_program = base.Program()
  1724. >>> cur_block = cur_program.current_block()
  1725. >>> new_variable = cur_block.create_var(name="X",
  1726. ... shape=[-1, 23, 48],
  1727. ... dtype='float32')
  1728. >>> print("persistable of current Var is: {}".format(new_variable.persistable))
  1729. persistable of current Var is: False
  1730. """
  1731. return self.desc.persistable()
  1732. @persistable.setter
  1733. def persistable(self, p):
  1734. self.desc.set_persistable(p)
  1735. @property
  1736. def is_parameter(self):
  1737. """
  1738. Indicating if current Variable is a Parameter
  1739. Examples:
  1740. .. code-block:: python
  1741. >>> import paddle
  1742. >>> paddle.enable_static()
  1743. >>> new_parameter = paddle.static.create_parameter(name="X",
  1744. ... shape=[10, 23, 48],
  1745. ... dtype='float32')
  1746. >>> if new_parameter.is_parameter:
  1747. ... print("Current var is a Parameter")
  1748. ... else:
  1749. ... print("Current var is not a Parameter")
  1750. Current var is a Parameter
  1751. """
  1752. return self.desc.is_parameter()
  1753. @is_parameter.setter
  1754. def is_parameter(self, p):
  1755. self.desc.set_is_parameter(p)
  1756. @property
  1757. def name(self):
  1758. """
  1759. Indicating name of current Variable
  1760. **Notes: If it has two or more Variable share the same name in the same** :ref:`api_guide_Block_en` **, it means these Variable will share content in no-** Dygraph **mode. This is how we achieve Parameter sharing**
  1761. Examples:
  1762. .. code-block:: python
  1763. >>> import paddle.base as base
  1764. >>> cur_program = base.Program()
  1765. >>> cur_block = cur_program.current_block()
  1766. >>> new_variable = cur_block.create_var(name="X",
  1767. ... shape=[-1, 23, 48],
  1768. ... dtype='float32')
  1769. >>> print("name of current Var is: {}".format(new_variable.name))
  1770. name of current Var is: X
  1771. """
  1772. return self.desc.name()
  1773. @property
  1774. def grad_name(self):
  1775. """
  1776. Indicating name of the gradient Variable of current Variable.
  1777. **Notes: This is a read-only property. It simply returns name of
  1778. gradient Variable from a naming convention but doesn't guarantee
  1779. the gradient exists.**
  1780. Examples:
  1781. .. code-block:: python
  1782. >>> import paddle
  1783. >>> paddle.enable_static()
  1784. >>> x = paddle.static.data(name="x", shape=[-1, 23, 48], dtype='float32')
  1785. >>> print(x.grad_name)
  1786. x@GRAD
  1787. """
  1788. return self.name + "@GRAD"
  1789. @name.setter
  1790. def name(self, new_name):
  1791. self.desc.set_name(new_name)
  1792. @property
  1793. def shape(self):
  1794. """
  1795. Indicating shape of current Variable
  1796. **Notes: This is a read-only property**
  1797. Examples:
  1798. .. code-block:: python
  1799. >>> import paddle.base as base
  1800. >>> cur_program = base.Program()
  1801. >>> cur_block = cur_program.current_block()
  1802. >>> new_variable = cur_block.create_var(name="X",
  1803. ... shape=[-1, 23, 48],
  1804. ... dtype='float32')
  1805. >>> print("shape of current Var is: {}".format(new_variable.shape))
  1806. shape of current Var is: [-1, 23, 48]
  1807. """
  1808. # convert to tuple, make it as same as numpy API.
  1809. return tuple(self.desc.shape())
  1810. @property
  1811. def dtype(self):
  1812. """
  1813. Indicating data type of current Variable
  1814. **Notes: This is a read-only property**
  1815. Examples:
  1816. .. code-block:: python
  1817. >>> import paddle.base as base
  1818. >>> cur_program = base.Program()
  1819. >>> cur_block = cur_program.current_block()
  1820. >>> new_variable = cur_block.create_var(name="X",
  1821. ... shape=[-1, 23, 48],
  1822. ... dtype='float32')
  1823. >>> print("Dtype of current Var is: {}".format(new_variable.dtype))
  1824. Dtype of current Var is: paddle.float32
  1825. """
  1826. return self.desc.dtype()
  1827. @property
  1828. def lod_level(self):
  1829. """
  1830. Indicating ``LoD`` info of current Variable, please refer to :ref:`api_paddle_Tensor` to check the meaning
  1831. of ``LoD``
  1832. **Notes**:
  1833. **1. This is a read-only property**
  1834. **2. Don't support this property in** Dygraph **mode, it's value should be** ``0(int)``
  1835. Examples:
  1836. .. code-block:: python
  1837. >>> import paddle
  1838. >>> import paddle.base as base
  1839. >>> paddle.enable_static()
  1840. >>> cur_program = base.Program()
  1841. >>> cur_block = cur_program.current_block()
  1842. >>> new_variable = cur_block.create_var(name="X",
  1843. ... shape=[-1, 23, 48],
  1844. ... dtype='float32')
  1845. >>> print("LoD Level of current Var is: {}".format(new_variable.lod_level))
  1846. LoD Level of current Var is: 0
  1847. """
  1848. if self.type == core.VarDesc.VarType.SELECTED_ROWS:
  1849. raise Exception("SelectedRows DO NOT support lod")
  1850. if self.type == core.VarDesc.VarType.STRINGS:
  1851. return None
  1852. return self.desc.lod_level()
  1853. @property
  1854. def type(self):
  1855. """
  1856. Indicating Type of current Variable
  1857. **Notes: This is a read-only property**
  1858. Examples:
  1859. .. code-block:: python
  1860. >>> import paddle.base as base
  1861. >>> cur_program = base.Program()
  1862. >>> cur_block = cur_program.current_block()
  1863. >>> new_variable = cur_block.create_var(name="X",
  1864. ... shape=[-1, 23, 48],
  1865. ... dtype='float32')
  1866. >>> print("Type of current Var is: {}".format(new_variable.type))
  1867. Type of current Var is: VarType.LOD_TENSOR
  1868. """
  1869. return self.desc.type()
  1870. @property
  1871. def T(self):
  1872. """
  1873. Permute current Variable with its dimensions reversed.
  1874. If `n` is the dimensions of `x` , `x.T` is equivalent to `x.transpose([n-1, n-2, ..., 0])`.
  1875. Examples:
  1876. .. code-block:: python
  1877. >>> import paddle
  1878. >>> paddle.enable_static()
  1879. >>> x = paddle.ones(shape=[2, 3, 5])
  1880. >>> x_T = x.T
  1881. >>> exe = paddle.static.Executor()
  1882. >>> x_T_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_T])[0]
  1883. >>> print(x_T_np.shape)
  1884. (5, 3, 2)
  1885. """
  1886. if len(self.shape) == 1:
  1887. return self
  1888. perm = []
  1889. for i in range(len(self.shape)):
  1890. perm.insert(0, i)
  1891. with unique_name.guard(self.block.program._name_generator):
  1892. out = self.block.create_var(
  1893. name=unique_name.generate_with_ignorable_key(
  1894. self.name + ".tmp"
  1895. ),
  1896. dtype=self.dtype,
  1897. type=self.type,
  1898. persistable=False,
  1899. stop_gradient=False,
  1900. )
  1901. input_shape = self.block.create_var(
  1902. name=unique_name.generate_with_ignorable_key(
  1903. self.name + ".tmp"
  1904. ),
  1905. dtype=self.dtype,
  1906. type=core.VarDesc.VarType.LOD_TENSOR,
  1907. persistable=False,
  1908. stop_gradient=False,
  1909. )
  1910. self.block.append_op(
  1911. type="transpose2",
  1912. inputs={"X": [self]},
  1913. outputs={"Out": [out], "XShape": [input_shape]},
  1914. attrs={"axis": perm},
  1915. )
  1916. return out
  1917. def clone(self):
  1918. """
  1919. Returns a new static Variable, which is the clone of the original static
  1920. Variable. It remains in the current graph, that is, the cloned Variable
  1921. provides gradient propagation. Calling ``out = tensor.clone()`` is same
  1922. as ``out = assign(tensor)`` .
  1923. Returns:
  1924. Variable, The cloned Variable.
  1925. Examples:
  1926. .. code-block:: python
  1927. >>> import paddle
  1928. >>> paddle.enable_static()
  1929. >>> # create a static Variable
  1930. >>> x = paddle.static.data(name='x', shape=[3, 2, 1])
  1931. >>> # create a cloned Variable
  1932. >>> y = x.clone()
  1933. """
  1934. with unique_name.guard(self.block.program._name_generator):
  1935. output = self.block.create_var(
  1936. name=unique_name.generate_with_ignorable_key(
  1937. self.name + "_clone"
  1938. ),
  1939. dtype=self.dtype,
  1940. type=self.type,
  1941. persistable=self.persistable,
  1942. stop_gradient=self.stop_gradient,
  1943. )
  1944. self.block.append_op(
  1945. type="assign",
  1946. inputs={"X": [self]},
  1947. outputs={"Out": [output]},
  1948. )
  1949. return output
  1950. def _set_error_clip(self, error_clip):
  1951. """
  1952. Set the error_clip.
  1953. Args:
  1954. error_clip(BaseErrorClipAttr) : The new error_clip.
  1955. Returns:
  1956. None
  1957. """
  1958. self.error_clip = error_clip
  1959. def _set_info(self, key, value):
  1960. """
  1961. Set key-value information for this variable.
  1962. Args:
  1963. key(str): Key for this information.
  1964. value(object): The value associated to the key.
  1965. Returns:
  1966. None
  1967. """
  1968. if not hasattr(self, "_info"):
  1969. self._info = {}
  1970. self._info[key] = value
  1971. def _get_info(self, key):
  1972. """
  1973. Get the information of this variable corresponding to key.
  1974. Args:
  1975. key(str): Key for this information.
  1976. Returns:
  1977. object
  1978. """
  1979. if hasattr(self, "_info") and key in self._info:
  1980. return self._info[key]
  1981. return None
  1982. def _slice_indices(self, slice, length):
  1983. """
  1984. Reference implementation for the slice.indices method.
  1985. """
  1986. # Compute step and length as integers.
  1987. step = 1 if slice.step is None else slice.step
  1988. # Raise ValueError for negative length or zero step.
  1989. if length < 0:
  1990. raise ValueError("length should not be negative")
  1991. if step == 0:
  1992. raise ValueError("slice step can not be zero")
  1993. # Find lower and upper bounds for start and stop.
  1994. lower = -1 if step < 0 else 0
  1995. upper = length - 1 if step < 0 else length
  1996. # Compute start.
  1997. if slice.start is None:
  1998. start = upper if step < 0 else lower
  1999. else:
  2000. start = slice.start
  2001. start = (
  2002. max(start + length, lower) if start < 0 else min(start, upper)
  2003. )
  2004. # Compute stop.
  2005. if slice.stop is None:
  2006. stop = lower if step < 0 else upper
  2007. else:
  2008. stop = slice.stop
  2009. stop = max(stop + length, lower) if stop < 0 else min(stop, upper)
  2010. return start, stop, step
  2011. def _detectEllipsis(self, item):
  2012. has_ellipsis = False
  2013. start = 0
  2014. end = len(self.shape)
  2015. for index, o in enumerate(item):
  2016. if o is Ellipsis:
  2017. if has_ellipsis:
  2018. raise ValueError("Index can have one ellipsis only.")
  2019. has_ellipsis = True
  2020. start = index
  2021. else:
  2022. if has_ellipsis:
  2023. end = index
  2024. return has_ellipsis, start, end
  2025. def _reconstructSliceinfo(self, item):
  2026. has_ellipsis, start, end = self._detectEllipsis(item)
  2027. if has_ellipsis:
  2028. newitem = []
  2029. for i in range(start):
  2030. newitem.append(item[i])
  2031. for i in range(start, end):
  2032. newitem.append(slice(None, None, None))
  2033. for i in range(end, len(item)):
  2034. newitem.append(item[i])
  2035. return newitem
  2036. else:
  2037. return None
  2038. def _detectContinuesSlice(self, item):
  2039. starts = []
  2040. ends = []
  2041. for index, o in enumerate(item):
  2042. if isinstance(o, int):
  2043. start = int(o)
  2044. if (index > 0 and index >= self.shape[index]) or (
  2045. index < 0 and (index + self.shape[index]) < 0
  2046. ):
  2047. raise IndexError("invalid index")
  2048. start = (
  2049. max(start + self.shape[index], 0)
  2050. if start < 0
  2051. else min(start, self.shape[index])
  2052. )
  2053. starts.append(start)
  2054. ends.append(start + 1)
  2055. elif isinstance(o, slice):
  2056. start, stop, step = self._slice_indices(o, self.shape[index])
  2057. if step == 1 or step == -1:
  2058. starts.append(start)
  2059. ends.append(stop)
  2060. else:
  2061. return False, None
  2062. else:
  2063. raise IndexError("Valid index accept int or slice or ellipsis")
  2064. return True, [starts, ends]
  2065. def _cloneVar(self, copy=False):
  2066. with unique_name.guard(self.block.program._name_generator):
  2067. if not copy:
  2068. return self.block.create_var(
  2069. name=unique_name.generate_with_ignorable_key(self.name),
  2070. dtype=self.dtype,
  2071. )
  2072. else:
  2073. return self
  2074. def _sliceVar(self, axes, starts, ends):
  2075. new_var = self._cloneVar()
  2076. self.block.append_op(
  2077. type="slice",
  2078. inputs={"Input": [self]},
  2079. outputs={"Out": [new_var]},
  2080. attrs={"axes": axes, "starts": starts, "ends": ends},
  2081. )
  2082. return new_var
  2083. def _concatVar(self, inputs, axis):
  2084. new_var = self._cloneVar()
  2085. self.block.append_op(
  2086. type="concat",
  2087. inputs={"X": inputs},
  2088. outputs={"Out": [new_var]},
  2089. attrs={
  2090. "axis": axis,
  2091. },
  2092. )
  2093. return new_var
  2094. def _sliceAndConcatVar(self, item, axis):
  2095. if isinstance(item, slice):
  2096. if self.shape[axis] < 0:
  2097. return self._cloneVar(True)
  2098. start, stop, step = self._slice_indices(item, self.shape[axis])
  2099. if step == 1:
  2100. return self._sliceVar([axis], [start], [stop])
  2101. else:
  2102. vars = []
  2103. if step > 0:
  2104. while start < stop:
  2105. vars.append(
  2106. self._sliceVar([axis], [start], [start + 1])
  2107. )
  2108. start += step
  2109. else:
  2110. while start > stop:
  2111. vars.append(
  2112. self._sliceVar([axis], [start], [start + 1])
  2113. )
  2114. start += step
  2115. return self._concatVar(vars, axis)
  2116. elif isinstance(item, int):
  2117. if self.shape[axis] < 0:
  2118. return self._cloneVar(True)
  2119. index = int(item)
  2120. if (index > 0 and index >= self.shape[axis]) or (
  2121. index < 0 and (index + self.shape[axis]) < 0
  2122. ):
  2123. raise IndexError("invalid index")
  2124. return self._sliceVar([axis], [index], [index + 1])
  2125. else:
  2126. raise IndexError("Valid index accept int or slice or tuple")
  2127. def __getitem__(self, item):
  2128. return _getitem_static(self, item)
  2129. def __setitem__(self, item, value):
  2130. from .dygraph.base import in_to_static_mode
  2131. if in_to_static_mode():
  2132. return _setitem_static(self, item, value)
  2133. else:
  2134. raise RuntimeError(
  2135. "In static mode, the __setitem__ (looks like: x[indices] = values) should not be used. Please use x = paddle.static.setitem(x, indices, values)"
  2136. )
  2137. def get_value(self, scope=None):
  2138. """
  2139. Get the value of variable in given scope.
  2140. Args:
  2141. scope(Scope, optional) : If `scope` is None, it will be set to global scope
  2142. obtained through 'paddle.static.global_scope()'. Otherwise, use `scope`.
  2143. Default: None
  2144. Returns:
  2145. Tensor, the value in given scope.
  2146. Examples:
  2147. .. code-block:: python
  2148. >>> import paddle
  2149. >>> import paddle.static as static
  2150. >>> import numpy as np
  2151. >>> paddle.enable_static()
  2152. >>> x = static.data(name="x", shape=[10, 10], dtype='float32')
  2153. >>> y = static.nn.fc(x, 10, name='fc')
  2154. >>> place = paddle.CPUPlace()
  2155. >>> exe = static.Executor(place)
  2156. >>> prog = paddle.static.default_main_program()
  2157. >>> exe.run(static.default_startup_program())
  2158. >>> inputs = np.ones((10, 10), dtype='float32')
  2159. >>> exe.run(prog, feed={'x': inputs}, fetch_list=[y, ])
  2160. >>> path = 'temp/tensor_'
  2161. >>> for var in prog.list_vars():
  2162. ... if var.persistable:
  2163. ... t = var.get_value()
  2164. ... paddle.save(t, path+var.name+'.pdtensor')
  2165. >>> for var in prog.list_vars():
  2166. ... if var.persistable:
  2167. ... t_load = paddle.load(path+var.name+'.pdtensor')
  2168. ... var.set_value(t_load)
  2169. """
  2170. # The 'framework' is a low-level module, and 'executor'
  2171. # can not be imported at the beginning of this file.
  2172. # Therefore, the above two modules are dynamically imported.
  2173. from .executor import global_scope
  2174. if scope is not None and not isinstance(scope, core._Scope):
  2175. raise TypeError(
  2176. f"`scope` should be None or `paddle.static.Scope` type, but received {type(scope)}."
  2177. )
  2178. if scope is None:
  2179. scope = global_scope()
  2180. var_temp = scope.find_var(self.name)
  2181. if var_temp is None:
  2182. raise ValueError(
  2183. f"Can not find Variable '{self.name}' in the Scope."
  2184. )
  2185. t = var_temp.get_tensor()
  2186. return t
  2187. def set_value(self, value, scope=None):
  2188. """
  2189. Set the value to the tensor in given scope.
  2190. Args:
  2191. value(Tensor/ndarray) : The value to be set.
  2192. scope(Scope, optional) : If `scope` is None, it will be set to global scope
  2193. obtained through 'paddle.static.global_scope()'. Otherwise, use `scope`.
  2194. Default: None
  2195. Returns:
  2196. None
  2197. Examples:
  2198. .. code-block:: python
  2199. >>> import paddle
  2200. >>> import paddle.static as static
  2201. >>> import numpy as np
  2202. >>> paddle.enable_static()
  2203. >>> x = static.data(name="x", shape=[10, 10], dtype='float32')
  2204. >>> y = static.nn.fc(x, 10, name='fc')
  2205. >>> place = paddle.CPUPlace()
  2206. >>> exe = static.Executor(place)
  2207. >>> prog = paddle.static.default_main_program()
  2208. >>> exe.run(static.default_startup_program())
  2209. >>> inputs = np.ones((10, 10), dtype='float32')
  2210. >>> exe.run(prog, feed={'x': inputs}, fetch_list=[y, ])
  2211. >>> path = 'temp/tensor_'
  2212. >>> for var in prog.list_vars():
  2213. ... if var.persistable:
  2214. ... t = var.get_value()
  2215. ... paddle.save(t, path+var.name+'.pdtensor')
  2216. >>> for var in prog.list_vars():
  2217. ... if var.persistable:
  2218. ... t_load = paddle.load(path+var.name+'.pdtensor')
  2219. ... var.set_value(t_load)
  2220. """
  2221. # The 'framework' is a low-level module, and 'executor'
  2222. # can not be imported at the beginning of this file.
  2223. # Therefore, the above two modules are dynamically imported.
  2224. from .executor import global_scope
  2225. if not (isinstance(value, np.ndarray) or hasattr(value, "__array__")):
  2226. raise TypeError(
  2227. f"`value` should be `numpy.ndarray` or `LoDTensor`, but received {type(value)}."
  2228. )
  2229. if scope is not None and not isinstance(scope, core._Scope):
  2230. raise TypeError(
  2231. f"`scope` should be None or `paddle.static.Scope` type, but received {type(scope)}."
  2232. )
  2233. if scope is None:
  2234. scope = global_scope()
  2235. var_temp = scope.find_var(self.name)
  2236. if var_temp is None:
  2237. raise ValueError(
  2238. f"Can not find Variable '{self.name}' in the Scope."
  2239. )
  2240. t = var_temp.get_tensor()
  2241. if hasattr(value, "shape"):
  2242. if isinstance(value.shape, (MethodType, FunctionType)):
  2243. value_shape = value.shape()
  2244. else:
  2245. value_shape = value.shape
  2246. if list(t.shape()) != list(value_shape):
  2247. raise ValueError(
  2248. f"{self.name} expected a shape {list(t.shape())}, but the received shape is {list(value_shape)}."
  2249. )
  2250. p = t._place()
  2251. if p.is_cpu_place():
  2252. place = core.CPUPlace()
  2253. elif p.is_cuda_pinned_place():
  2254. place = core.CUDAPinnedPlace()
  2255. elif p.is_xpu_place():
  2256. p = core.Place()
  2257. p.set_place(t._place())
  2258. place = core.XPUPlace(p.xpu_device_id())
  2259. elif p.is_custom_place():
  2260. p = core.Place()
  2261. p.set_place(t._place())
  2262. place = core.CustomPlace(
  2263. p.custom_device_type(), p.custom_device_id()
  2264. )
  2265. else:
  2266. p = core.Place()
  2267. p.set_place(t._place())
  2268. place = core.CUDAPlace(p.gpu_device_id())
  2269. t.set(value, place)
  2270. def size(self):
  2271. """
  2272. Returns the number of elements for current Variable, which is a int64 Variable with shape [] .
  2273. Returns:
  2274. Variable, the number of elements for current Variable
  2275. Examples:
  2276. .. code-block:: python
  2277. >>> import paddle
  2278. >>> paddle.enable_static()
  2279. >>> # create a static Variable
  2280. >>> x = paddle.static.data(name='x', shape=[3, 2, 1])
  2281. >>> # get the number of elements of the Variable
  2282. >>> y = x.size()
  2283. """
  2284. with unique_name.guard(self.block.program._name_generator):
  2285. output = self.block.create_var(
  2286. name=unique_name.generate_with_ignorable_key(
  2287. self.name + "_size"
  2288. ),
  2289. dtype=core.VarDesc.VarType.INT64,
  2290. )
  2291. self.block.append_op(
  2292. type="size",
  2293. inputs={"Input": [self]},
  2294. outputs={"Out": [output]},
  2295. )
  2296. return output
  2297. def _set_attr(self, name, val):
  2298. """
  2299. Set the value of attribute by attribute's name.
  2300. Args:
  2301. name(str): the attribute name.
  2302. val(int|str|list): the value of the attribute.
  2303. """
  2304. self._update_desc_attr(name, val)
  2305. def _has_attr(self, name):
  2306. """
  2307. Whether this Variable has the attribute with the name `name` or not.
  2308. Args:
  2309. name(str): the attribute name.
  2310. Returns:
  2311. bool, True if has this attribute.
  2312. """
  2313. return self.desc.has_attr(name)
  2314. def _remove_attr(self, name):
  2315. self.desc.remove_attr(name)
  2316. def _update_desc_attr(self, name, val):
  2317. """
  2318. Update the value of desc's attribute by attribute's name.
  2319. Args:
  2320. name(str): the attribute name.
  2321. val(int|str|list): the value of the attribute.
  2322. """
  2323. self.desc._set_attr(name, val)
  2324. @property
  2325. def attr_names(self):
  2326. """Get the names of all attributes defined."""
  2327. return self.desc.attr_names()
  2328. def attr(self, name):
  2329. """
  2330. Get the attribute by name.
  2331. Args:
  2332. name(str): the attribute name.
  2333. Returns:
  2334. int|str|list, The attribute value. The return value
  2335. can be any valid attribute type.
  2336. """
  2337. return self.desc.attr(name)
  2338. @property
  2339. def dist_attr(self):
  2340. """
  2341. Get distributed attribute of this Variable.
  2342. """
  2343. return self.desc.dist_attr
  2344. @dist_attr.setter
  2345. def dist_attr(self, dist_attr):
  2346. """
  2347. Set distributed attribute of this Variable.
  2348. """
  2349. self.desc.dist_attr = dist_attr
  2350. def get_all_op_protos():
  2351. """
  2352. Get all registered op proto from PaddlePaddle C++ end.
  2353. Returns:
  2354. list: list of OpProto.
  2355. """
  2356. protostrs = core.get_all_op_protos()
  2357. ret_values = []
  2358. for pbstr in protostrs:
  2359. op_proto = framework_pb2.OpProto.FromString(bytes(pbstr))
  2360. ret_values.append(op_proto)
  2361. return ret_values
  2362. class OpProtoHolder:
  2363. """
  2364. A global variable to hold all OpProtos from C++ as a map
  2365. """
  2366. @classmethod
  2367. def instance(cls):
  2368. if not hasattr(cls, "_instance"):
  2369. cls._instance = cls()
  2370. return cls._instance
  2371. def __init__(self):
  2372. assert not hasattr(
  2373. self.__class__, "_instance"
  2374. ), "Please use `instance()` to get OpProtoHolder object!"
  2375. op_protos = get_all_op_protos()
  2376. self.op_proto_map = {}
  2377. for proto in op_protos:
  2378. self.op_proto_map[proto.type] = proto
  2379. def get_op_proto(self, type):
  2380. """
  2381. Get OpProto by a type string.
  2382. Args:
  2383. type(str): The type that operator registered in C++ side.
  2384. Returns(framework_pb2.OpProto): The OpProto
  2385. """
  2386. if type not in self.op_proto_map:
  2387. raise ValueError('Operator "%s" has not been registered.' % type)
  2388. return self.op_proto_map[type]
  2389. def update_op_proto(self):
  2390. op_protos = get_all_op_protos()
  2391. custom_op_names = []
  2392. for proto in op_protos:
  2393. if proto.type not in self.op_proto_map:
  2394. self.op_proto_map[proto.type] = proto
  2395. custom_op_names.append(proto.type)
  2396. return custom_op_names
  2397. def has_op_proto(self, type):
  2398. return type in self.op_proto_map
  2399. @staticmethod
  2400. def generated_op_attr_names():
  2401. return {
  2402. core.op_proto_and_checker_maker.kOpRoleAttrName(),
  2403. core.op_proto_and_checker_maker.kOpRoleVarAttrName(),
  2404. core.op_proto_and_checker_maker.kOpNameScopeAttrName(),
  2405. core.op_proto_and_checker_maker.kOpCreationCallstackAttrName(),
  2406. core.op_proto_and_checker_maker.kOpDeviceAttrName(),
  2407. }
class Operator:
    """
    In Fluid, every operation is represented by an Operator, and an Operator
    is regarded as a built-in instruction of a Block. Users describe their
    neural network with these built-in instructions.

    Args:
        block(Block): The block that holds the current operator.
        desc(core.OpDesc): The protobuf description of the Operator.
        type(str): The type of operator. Default None.
        inputs(dict): The inputs of this Operator. It is a dictionary: for
            every element, the key is the input parameter name and the value
            is a list of variables. Default None.
        outputs(dict): The outputs of this Operator. It is a dictionary: for
            every element, the key is the output parameter name and the value
            is a list of variables. Default None.
        attrs(dict): The attributes of this Operator. It is a dictionary: for
            every element, the key is the attribute name and the value is the
            attribute value. The attribute type should be the same as the type
            registered on the C++ side. Default None.

    Returns:
        Operator: The initialized Operator.

    Raises:
        ValueError: If the passed inputs, outputs and attrs don't match the
            ones registered on the C++ side for the Operator being
            initialized.

    Notes:
        The constructor of Operator should not be invoked directly. Use
        Block.append_op or Block._prepend_op instead.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> paddle.enable_static()
            >>> cur_program = paddle.static.Program()
            >>> cur_block = cur_program.current_block()
            >>> var1 = cur_block.create_var(name="var1", shape=[-1, 23, 48], dtype='float32')
            >>> var2 = cur_block.create_var(name="var2", shape=[-1, 23, 48], dtype='float32')
            >>> var3 = cur_block.create_var(name="var3", shape=[-1, 23, 48], dtype='float32')
            >>> var1 += var2 + var3
            >>> cur_block.append_op(type="sum",
            ...                     inputs={"X": [var1, var2, var3]},
            ...                     outputs={"Out": [var1]})
    """

    # Operator types for which `_has_kernel` returns False. For these types
    # `__init__` warns instead of assigning a device under a device guard and
    # skips the var-type / shape inference step.
    OP_WITHOUT_KERNEL_SET = {
        "feed",
        "fetch",
        "recurrent",
        "go",
        "conditional_block",
        "pylayer",
        "while",
        "send",
        "recv",
        "listen_and_serv",
        "fl_listen_and_serv",
        "ncclInit",
        "select",
        "checkpoint_notify",
        "gen_bkcl_id",
        "c_gen_bkcl_id",
        "gen_nccl_id",
        "c_gen_nccl_id",
        "c_comm_init",
        "c_sync_calc_stream",
        "c_sync_comm_stream",
        "queue_generator",
        "dequeue",
        "enqueue",
        "heter_listen_and_serv",
        "c_wait_comm",
        "c_wait_compute",
    }
  2478. def __init__(
  2479. self, block, desc, type=None, inputs=None, outputs=None, attrs=None
  2480. ):
  2481. # read attr type index from op proto to avoid unexpected type
  2482. # conversions, e.g. narrowing conversion like double to float
  2483. try:
  2484. proto = OpProtoHolder.instance().get_op_proto(type)
  2485. self._attr_types = {}
  2486. for attr in proto.attrs:
  2487. self._attr_types[attr.name] = attr.type
  2488. except ValueError:
  2489. pass
  2490. if in_dygraph_mode():
  2491. if type is None:
  2492. raise ValueError(
  2493. "`type` to initialized an Operator can not be None."
  2494. )
  2495. self._type = type
  2496. self.attrs = attrs if attrs else {}
  2497. else:
  2498. self.block = block
  2499. self.desc = desc
  2500. # note: not add self.attrs here:
  2501. # https://github.com/PaddlePaddle/Paddle/pull/12583#pullrequestreview-145093173
  2502. op_attrs = attrs
  2503. if op_attrs is None:
  2504. op_attrs = {}
  2505. del attrs
  2506. # attr for static graph mode cuda graph
  2507. self._cuda_graph_attr = _current_cuda_graph_mode
  2508. # attr for OP AMP mode
  2509. # using dynamic import to avoid cyclic dependency
  2510. from paddle.static.amp.fp16_utils import DEFAULT_AMP_OPTIONS
  2511. self._amp_options: AmpOptions = DEFAULT_AMP_OPTIONS
  2512. # record the call path of op, only used in AutoParallel
  2513. self._struct_name = _full_name_struct()
  2514. op_maker = core.op_proto_and_checker_maker
  2515. if op_maker.kOpRoleAttrName() not in op_attrs:
  2516. op_attrs[
  2517. op_maker.kOpRoleAttrName()
  2518. ] = self.block.program._op_role
  2519. role_var_name = op_maker.kOpRoleVarAttrName()
  2520. if (
  2521. len(self.block.program._op_role_var) != 0
  2522. and role_var_name not in op_attrs
  2523. ):
  2524. op_attrs[role_var_name] = self.block.program._op_role_var
  2525. if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0:
  2526. del op_attrs[role_var_name]
  2527. if len(self.desc.type()) != 0:
  2528. # NOTE(Aurelius84): prog.clone() will lead that var.op is always None,
  2529. # we add this to fix the problem.
  2530. for arg in self.desc.output_arg_names():
  2531. if block.has_var(arg) and block.var(arg).op is None:
  2532. block.var(arg).op = self
  2533. return
  2534. if type is None:
  2535. raise ValueError(
  2536. "`type` to initialized an Operator can not be None."
  2537. )
  2538. else:
  2539. callstack_var_name = op_maker.kOpCreationCallstackAttrName()
  2540. op_attrs[callstack_var_name] = []
  2541. for frame in traceback.extract_stack():
  2542. op_attrs[callstack_var_name].append(
  2543. f' File "{frame[0]}", line {frame[1]}, in {frame[2]}'
  2544. )
  2545. op_attrs[callstack_var_name].append(f" {frame[3]}")
  2546. self.desc.set_type(type)
  2547. proto = OpProtoHolder.instance().get_op_proto(type)
  2548. namescope_var_name = op_maker.kOpNameScopeAttrName()
  2549. op_attrs[namescope_var_name] = _full_name_scope()
  2550. # set device for op with kernels, give warning for op without kernels
  2551. # when force_cpu and device_guard are used at the same time, a warning will be given.
  2552. # TODO(zhangting2020): when force_cpu is removed, clear warning below.
  2553. if _current_device is not None:
  2554. if self._has_kernel(type):
  2555. op_device = op_maker.kOpDeviceAttrName()
  2556. op_attrs[op_device] = _current_device
  2557. else:
  2558. warnings.warn(
  2559. "The Op(%s) is not support to set device." % type
  2560. )
  2561. if "force_cpu" in op_attrs:
  2562. if (
  2563. type == "less_than"
  2564. and op_attrs["force_cpu"] is not None
  2565. ) or op_attrs["force_cpu"] is not False:
  2566. warnings.warn(
  2567. "The Attr(force_cpu) of Op(%s) will be deprecated in the future, "
  2568. "please use 'device_guard' instead. 'device_guard' has higher priority when they are "
  2569. "used at the same time." % type
  2570. )
  2571. if _current_pipeline_stage is not None:
  2572. pipeline_attr_name = (
  2573. "pipeline_stage" + core.kAutoParallelSuffix()
  2574. )
  2575. self._update_desc_attr(
  2576. pipeline_attr_name, _current_pipeline_stage
  2577. )
  2578. def find_name(var_list, name):
  2579. for var_name in var_list:
  2580. if var_list[var_name] is not None and var_name == name:
  2581. return True
  2582. return False
  2583. if inputs is not None:
  2584. for in_proto in proto.inputs:
  2585. found = find_name(inputs, in_proto.name)
  2586. assert (
  2587. found or in_proto.dispensable
  2588. ), f"Input {in_proto.name} not found"
  2589. if found:
  2590. in_args = inputs[in_proto.name]
  2591. if not isinstance(in_args, (list, tuple)):
  2592. in_args = [in_args]
  2593. if not in_proto.duplicable and len(in_args) > 1:
  2594. raise ValueError(
  2595. "Input %s expects only one input, but %d are given."
  2596. % (in_proto.name, len(in_args))
  2597. )
  2598. in_arg_names = []
  2599. for index, arg in enumerate(in_args):
  2600. if isinstance(arg, str):
  2601. in_arg_names.append(arg)
  2602. elif isinstance(arg, bytes):
  2603. in_arg_names.append(arg.decode())
  2604. elif isinstance(arg, (Variable, core.eager.Tensor)):
  2605. in_arg_names.append(arg.name)
  2606. else:
  2607. raise TypeError(
  2608. f"The type of '%{in_proto.name}' in operator {type} should be "
  2609. f"one of [str, bytes, Variable]. but received : {arg}"
  2610. )
  2611. self.desc.set_input(in_proto.name, in_arg_names)
  2612. else:
  2613. self.desc.set_input(in_proto.name, [])
  2614. if outputs is not None:
  2615. for m in proto.outputs:
  2616. if (m.name not in outputs) and m.dispensable:
  2617. continue
  2618. # FIXME: The outputs of primitive operator currently
  2619. # doesn't include intermediate output as it will be dropped
  2620. # in operator codegen, such as xshape output of reshape2.
  2621. # It will fixed when the operator codegen support
  2622. # intermediate output.
  2623. if core._is_bwd_prim_enabled():
  2624. if not (
  2625. (m.name in outputs)
  2626. or m.dispensable
  2627. or m.intermediate
  2628. ):
  2629. raise ValueError(
  2630. "Incorrect setting for output(s) of "
  2631. f'operator "{type}", should set: [{m.name}].'
  2632. )
  2633. else:
  2634. if not ((m.name in outputs) or m.dispensable):
  2635. raise ValueError(
  2636. "Incorrect setting for output(s) of "
  2637. f'operator "{type}", should set: [{m.name}].'
  2638. )
  2639. for out_proto in proto.outputs:
  2640. if out_proto.name not in outputs:
  2641. continue
  2642. out_args = outputs[out_proto.name]
  2643. if not isinstance(out_args, list):
  2644. out_args = [out_args]
  2645. if not out_proto.duplicable and len(out_args) > 1:
  2646. raise ValueError(
  2647. "Output %s expects only one output, but %d are given."
  2648. % (out_proto.name, len(out_args))
  2649. )
  2650. out_arg_names = []
  2651. for arg in out_args:
  2652. if isinstance(arg, str):
  2653. out_arg_names.append(arg)
  2654. else:
  2655. out_arg_names.append(arg.name)
  2656. # TODO(minqiyang): could we remove variable's op in static graph mode?
  2657. if not in_dygraph_mode():
  2658. if isinstance(arg, str):
  2659. block.var(arg).op = self
  2660. else:
  2661. arg.op = self
  2662. self.desc.set_output(out_proto.name, out_arg_names)
  2663. extra_attrs_map = core.get_op_extra_attrs(type)
  2664. if op_attrs is not None:
  2665. if not isinstance(op_attrs, dict):
  2666. raise TypeError("'attrs' should be a dict.")
  2667. for attr in proto.attrs:
  2668. attr_name = attr.name
  2669. if (attr_name not in op_attrs) or (
  2670. op_attrs[attr_name] is None
  2671. ):
  2672. continue
  2673. attr_val = op_attrs[attr_name]
  2674. self._update_desc_attr(attr_name, attr_val)
  2675. for attr_name in extra_attrs_map.keys():
  2676. if os.environ.get("FLAGS_print_extra_attrs", "0") == "1":
  2677. warnings.warn(f"op {type} use extra_attr: {attr_name}")
  2678. if (attr_name not in op_attrs) or (
  2679. op_attrs[attr_name] is None
  2680. ):
  2681. self._update_desc_attr(
  2682. attr_name, extra_attrs_map[attr_name]
  2683. )
  2684. else:
  2685. self._update_desc_attr(attr_name, op_attrs[attr_name])
  2686. if os.environ.get("FLAGS_print_extra_attrs", "0") == "1":
  2687. if type in extra_op_attrs:
  2688. attrs = extra_op_attrs.get(type, [])
  2689. for attr in attrs:
  2690. if attr in op_attrs.keys():
  2691. warnings.warn(
  2692. f"op {type} use extra_attr: {attr}"
  2693. )
  2694. if type in special_op_attrs:
  2695. attrs = special_op_attrs.get(type, [])
  2696. for attr in attrs:
  2697. a_name = list(attr.keys())[0]
  2698. default_value = list(attr.values())[0]
  2699. if (
  2700. a_name in op_attrs.keys()
  2701. and default_value != op_attrs[a_name]
  2702. ):
  2703. warnings.warn(
  2704. f"op {type}'s attr {a_name} = {op_attrs[a_name]} is not the default value: {default_value}"
  2705. )
  2706. # proto.attrs doesn't include ipu_index
  2707. if core.is_compiled_with_ipu():
  2708. if global_ipu_index >= 0:
  2709. self._update_desc_attr(
  2710. ipu_index_attr_name, global_ipu_index
  2711. )
  2712. if global_ipu_stage >= 0:
  2713. self._update_desc_attr(
  2714. ipu_stage_attr_name, global_ipu_stage
  2715. )
  2716. self.desc.check_attrs()
  2717. if self._has_kernel(type):
  2718. self.desc.infer_var_type(self.block.desc)
  2719. self.desc.infer_shape(self.block.desc)
  2720. def _has_kernel(self, op_type):
  2721. return op_type not in self.OP_WITHOUT_KERNEL_SET
  2722. def to_string(self, throw_on_error):
  2723. """
  2724. Get debug string.
  2725. Args:
  2726. throw_on_error(bool): Whether to raise exception if self is not
  2727. initialized.
  2728. Returns:
  2729. str: The debug string.
  2730. """
  2731. protostr = self.desc.serialize_to_string()
  2732. proto = framework_pb2.OpDesc.FromString(bytes(protostr))
  2733. return _debug_string_(proto, throw_on_error)
  2734. def _to_readable_code(self, skip_op_callstack=True):
  2735. """
  2736. Get readable debug string of Operator.
  2737. .. note::
  2738. If you want to get the debug string in protobuf format,
  2739. please use :code:`to_string` method.
  2740. Args:
  2741. skip_op_callstack(bool): whether to skip parsing Operator's attribute
  2742. op_callstack, default value is True
  2743. Returns:
  2744. string: The formatted Operator string.
  2745. Examples:
  2746. .. code-block:: python
  2747. >>> import paddle
  2748. >>> paddle.enable_static()
  2749. >>> cur_program = paddle.static.Program()
  2750. >>> cur_block = cur_program.current_block()
  2751. >>> var = cur_block.create_var(name="X",
  2752. ... shape=[-1, 23, 48],
  2753. ... dtype='float32')
  2754. >>> new_op = cur_block.append_op(type="abs",
  2755. ... inputs={"X": [var]},
  2756. ... outputs={"Out": [var]})
  2757. >>> print(new_op._to_readable_code())
  2758. """
  2759. assert isinstance(
  2760. skip_op_callstack, bool
  2761. ), f"skip_op_callstack parameter's type is error, expect bool, received {type(skip_op_callstack)}"
  2762. outputs_str = "{"
  2763. for i in range(0, len(self.output_names)):
  2764. outputs_str += f"{self.output_names[i]}="
  2765. o = self.output(self.output_names[i])
  2766. outputs_str += f"{o}"
  2767. if i != len(self.output_names) - 1:
  2768. outputs_str += ", "
  2769. outputs_str += "}"
  2770. inputs_str = "{"
  2771. for i in range(0, len(self.input_names)):
  2772. inputs_str += f"{self.input_names[i]}="
  2773. o = self.input(self.input_names[i])
  2774. inputs_str += f"{o}"
  2775. if i != len(self.input_names) - 1:
  2776. inputs_str += ", "
  2777. inputs_str += "}"
  2778. attr_names = sorted(self.attr_names)
  2779. attrs_str = ""
  2780. for i in range(0, len(attr_names)):
  2781. name = attr_names[i]
  2782. if skip_op_callstack and name == "op_callstack":
  2783. continue
  2784. attr_type = self.desc.attr_type(name, True)
  2785. if attr_type == core.AttrType.VAR:
  2786. attr_var_name = self.desc.attr(name, True).name()
  2787. a = f"{name} = Var['{attr_var_name}']"
  2788. attrs_str += a
  2789. if i != len(attr_names) - 1:
  2790. attrs_str += ", "
  2791. continue
  2792. if attr_type == core.AttrType.VARS:
  2793. attr_var_names = [
  2794. "'%s'" % var.name() for var in self.desc.attr(name, True)
  2795. ]
  2796. a = "{name} = Vars[{value}]".format(
  2797. name=name, value=",".join(attr_var_names)
  2798. )
  2799. attrs_str += a
  2800. if i != len(attr_names) - 1:
  2801. attrs_str += ", "
  2802. continue
  2803. if attr_type == core.AttrType.BLOCK:
  2804. a = f"{name} = block[{self._block_attr_id(name)}]"
  2805. attrs_str += a
  2806. if i != len(attr_names) - 1:
  2807. attrs_str += ", "
  2808. continue
  2809. if attr_type == core.AttrType.BLOCKS:
  2810. a = f"{name} = blocks{self._blocks_attr_ids(name)}"
  2811. attrs_str += a
  2812. if i != len(attr_names) - 1:
  2813. attrs_str += ", "
  2814. continue
  2815. # it is bytes of serialized protobuf
  2816. if (
  2817. is_compiled_with_cinn()
  2818. and self.type == "cinn_launch"
  2819. and name == "compilation_key"
  2820. ):
  2821. key = self.desc.attr(name)
  2822. v = core.get_serialize_comile_key(key)
  2823. prog = Program()
  2824. prog = prog.parse_from_string(v)
  2825. s = prog._to_readable_code()
  2826. lines = s.split("\n")
  2827. value = "\n".join([" " + line for line in lines])
  2828. value = "\n" + value
  2829. else:
  2830. value = self.desc.attr(name)
  2831. a = f"{name} = {value}"
  2832. attrs_str += a
  2833. if i != len(attr_names) - 1:
  2834. attrs_str += ", "
  2835. from paddle.distributed.auto_parallel.static.dist_context import (
  2836. get_default_distributed_context,
  2837. )
  2838. dist_context = get_default_distributed_context()
  2839. dist_op = dist_context.get_dist_op_for_program(self)
  2840. if dist_op is not None:
  2841. attrs_str += ", {name} = {value}".format(
  2842. name="dist_attr", value=dist_op
  2843. )
  2844. if outputs_str != "{}":
  2845. op_str = (
  2846. f"{outputs_str} = {self.type}(inputs={inputs_str}, {attrs_str})"
  2847. )
  2848. else:
  2849. op_str = f"{self.type}(inputs={inputs_str}, {attrs_str})"
  2850. return op_str
  2851. def __str__(self):
  2852. return self._to_readable_code()
  2853. __repr__ = __str__
  2854. @property
  2855. def type(self):
  2856. return self.desc.type()
  2857. def input(self, name):
  2858. r"""
  2859. Get the input arguments according to the input parameter name.
  2860. Args:
  2861. name(str): The input parameter name.
  2862. Returns:
  2863. list, return the list of argument names that associated with \
  2864. the specific parameter name.
  2865. """
  2866. return self.desc.input(name)
  2867. def _rename_input(self, old_name, new_name):
  2868. """
  2869. Rename the `old_name` to `new_name`.
  2870. Args:
  2871. old_name(str): The old name of the Operator's input.
  2872. new_name(str): The new name of the Operator's input.
  2873. Returns:
  2874. None
  2875. """
  2876. self.desc._rename_input(old_name, new_name)
  2877. def _rename_output(self, old_name, new_name):
  2878. """
  2879. Rename the `old_name` to `new_name`.
  2880. Args:
  2881. old_name(str): The old name of the Operator's output.
  2882. new_name(str): The new name of the Operator's output.
  2883. Returns:
  2884. None
  2885. """
  2886. self.desc._rename_output(old_name, new_name)
  2887. @property
  2888. def input_names(self):
  2889. return self.desc.input_names()
  2890. @property
  2891. def input_arg_names(self):
  2892. return self.desc.input_arg_names()
  2893. @property
  2894. def output_arg_names(self):
  2895. return self.desc.output_arg_names()
  2896. def output(self, name):
  2897. r"""
  2898. Get output arguments by the output parameter name.
  2899. Args:
  2900. name(str): The output parameter name.
  2901. Returns:
  2902. list: return the list of argument names associated with \
  2903. the specific parameter name.
  2904. """
  2905. return self.desc.output(name)
  2906. @property
  2907. def output_names(self):
  2908. return self.desc.output_names()
  2909. @property
  2910. def idx(self):
  2911. for i, op in enumerate(self.block.ops):
  2912. if op == self:
  2913. return i
  2914. raise ValueError(
  2915. "Can't find op itself in it's block. It could be a bug of Paddle."
  2916. )
  2917. def has_attr(self, name):
  2918. """
  2919. Whether this Operator has the attribute with name or not.
  2920. Args:
  2921. name(str): the attribute name.
  2922. Returns:
  2923. bool: True if has this attribute.
  2924. """
  2925. return self.desc.has_attr(name)
  2926. def attr_type(self, name):
  2927. """
  2928. Get the type of attribute by attribute's name.
  2929. Args:
  2930. name(str): the attribute name.
  2931. Returns:
  2932. core.AttrType: the attribute type.
  2933. """
  2934. return self.desc.attr_type(name, True)
  2935. def _set_attr(self, name, val):
  2936. """
  2937. Set the value of attribute by attribute's name.
  2938. Args:
  2939. name(str): the attribute name.
  2940. val(bool|int|str|float|list): the value of the attribute.
  2941. Raises:
  2942. ValueError: If the type of value doesn't match with desc.attr_type(name).
  2943. """
  2944. self._update_desc_attr(name, val)
  2945. def _remove_attr(self, name):
  2946. self.desc.remove_attr(name)
  2947. def _update_desc_attr(self, name, val):
  2948. """
  2949. Update the value of desc's attribute by attribute's name.
  2950. Args:
  2951. name(str): the attribute name.
  2952. val(bool|int|str|float|list): the value of the attribute.
  2953. Raises:
  2954. ValueError: If the type of value doesn't match with desc.attr_type(name).
  2955. """
  2956. if isinstance(val, Variable):
  2957. self.desc.set_var_attr(name, val.desc)
  2958. elif isinstance(val, list) and _all_is_type(val, Variable):
  2959. self.desc.set_vars_attr(name, [v.desc for v in val])
  2960. elif isinstance(val, Block):
  2961. self.desc.set_block_attr(name, val.desc)
  2962. elif isinstance(val, list) and val and _all_is_type(val, Block):
  2963. self.desc.set_blocks_attr(name, [v.desc for v in val])
  2964. elif isinstance(val, (core.BlockDesc, core.ProgramDesc)):
  2965. self.desc.set_serialized_attr(name, val.serialize_to_string())
  2966. else:
  2967. self._update_desc_plain_attr(name, val)
  2968. def _update_desc_plain_attr(self, name, val):
  2969. desc = self.desc
  2970. if not hasattr(self, "_attr_types") or (name not in self._attr_types):
  2971. desc._set_attr(name, val)
  2972. return
  2973. type_index = self._attr_types[name]
  2974. # if the required attribute is a SCALAR, pass as-is
  2975. if type_index == core.AttrType.SCALAR:
  2976. desc._set_scalar_attr(name, wrap_as_scalar(val))
  2977. elif type_index == core.AttrType.SCALARS:
  2978. desc._set_scalars_attr(name, wrap_as_scalars(val))
  2979. elif type_index == core.AttrType.BOOL:
  2980. desc._set_bool_attr(name, val)
  2981. elif type_index == core.AttrType.INT:
  2982. desc._set_int32_attr(name, val)
  2983. elif type_index == core.AttrType.LONG:
  2984. desc._set_int64_attr(name, val)
  2985. elif type_index == core.AttrType.FLOAT:
  2986. desc._set_float32_attr(name, val)
  2987. elif type_index == core.AttrType.FLOAT64:
  2988. desc._set_float64_attr(name, val)
  2989. elif type_index == core.AttrType.STRING:
  2990. desc._set_str_attr(name, val)
  2991. elif type_index == core.AttrType.BOOLS:
  2992. desc._set_bools_attr(name, val)
  2993. elif type_index == core.AttrType.INTS:
  2994. desc._set_int32s_attr(name, val)
  2995. elif type_index == core.AttrType.LONGS:
  2996. desc._set_int64s_attr(name, val)
  2997. elif type_index == core.AttrType.FLOATS:
  2998. desc._set_float32s_attr(name, val)
  2999. elif type_index == core.AttrType.FLOAT64S:
  3000. desc._set_float64s_attr(name, val)
  3001. elif type_index == core.AttrType.STRINGS:
  3002. desc._set_strs_attr(name, val)
  3003. else:
  3004. # defaults to old methods
  3005. desc._set_attr(name, val)
  3006. @property
  3007. def attr_names(self):
  3008. return self.desc.attr_names(True)
  3009. def attr(self, name):
  3010. """
  3011. Get the attribute by name.
  3012. Args:
  3013. name(str): the attribute name.
  3014. Returns:
  3015. bool|int|str|float|list: The attribute value. The return value
  3016. can be any valid attribute type.
  3017. """
  3018. return self.desc.attr(name)
  3019. def _block_attr_id(self, name):
  3020. """
  3021. Get the block attribute's id by name.
  3022. Args:
  3023. name(str): the attribute name.
  3024. Returns:
  3025. int: the block index.
  3026. """
  3027. return self.desc._block_attr_id(name)
  3028. def _block_attr(self, name):
  3029. """
  3030. Get the block attribute by name.
  3031. Args:
  3032. name(str): the attribute name.
  3033. Returns:
  3034. block: the block attribute.
  3035. """
  3036. id = self._block_attr_id(name)
  3037. assert id >= 0 and id < len(self.block.program.blocks)
  3038. return self.block.program.blocks[id]
  3039. def _blocks_attr(self, name):
  3040. """
  3041. Get the blocks attribute by name.
  3042. Args:
  3043. name(str): the attribute name.
  3044. Returns:
  3045. list: list of the blocks attribute.
  3046. """
  3047. attrs = []
  3048. for i in self._blocks_attr_ids(name):
  3049. assert i >= 0 and i < len(self.block.program.blocks)
  3050. attrs.append(self.block.program.blocks[i])
  3051. return attrs
  3052. def _blocks_attr_ids(self, name):
  3053. """
  3054. Get the blocks attribute's ids by name.
  3055. Args:
  3056. name(str): the attribute name.
  3057. Returns:
  3058. list: list of the blocks ids.
  3059. """
  3060. return self.desc._blocks_attr_ids(name)
  3061. def _var_attr(self, name):
  3062. """
  3063. Get the Variable attribute by name.
  3064. Args:
  3065. name(str): the attribute name.
  3066. Returns:
  3067. Variable: the Variable attribute.
  3068. """
  3069. attr_type = self.desc.attr_type(name, True)
  3070. assert (
  3071. attr_type == core.AttrType.VAR
  3072. ), f"Required type attr({name}) is Variable, but received {attr_type}"
  3073. attr_var_name = self.desc.attr(name, True).name()
  3074. return self.block._var_recursive(attr_var_name)
  3075. def _vars_attr(self, name):
  3076. """
  3077. Get the Variables attribute by name.
  3078. Args:
  3079. name(str): the attribute name.
  3080. Returns:
  3081. Variables: the Variables attribute.
  3082. """
  3083. attr_type = self.desc.attr_type(name, True)
  3084. assert (
  3085. attr_type == core.AttrType.VARS
  3086. ), f"Required type attr({name}) is list[Variable], but received {attr_type}"
  3087. attr_vars = [
  3088. self.block._var_recursive(var.name())
  3089. for var in self.desc.attr(name, True)
  3090. ]
  3091. return attr_vars
  3092. def all_attrs(self):
  3093. """
  3094. Get the attribute dict.
  3095. Returns:
  3096. dict: The Operator's attribute dict, name->attr.
  3097. """
  3098. attr_names = self.attr_names
  3099. attr_map = {}
  3100. for n in attr_names:
  3101. attr_type = self.desc.attr_type(n, True)
  3102. if attr_type == core.AttrType.BLOCK:
  3103. attr_map[n] = self._block_attr(n)
  3104. elif attr_type == core.AttrType.BLOCKS:
  3105. attr_map[n] = self._blocks_attr(n)
  3106. elif attr_type == core.AttrType.VAR:
  3107. attr_map[n] = self._var_attr(n)
  3108. elif attr_type == core.AttrType.VARS:
  3109. attr_map[n] = self._vars_attr(n)
  3110. else:
  3111. attr_map[n] = self.attr(n)
  3112. return attr_map
  3113. def _is_optimize_op(self):
  3114. op_maker = core.op_proto_and_checker_maker
  3115. OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
  3116. if not self.desc.has_attr(op_maker.kOpRoleAttrName()):
  3117. return False
  3118. op_role = self.desc.attr(op_maker.kOpRoleAttrName())
  3119. if op_role & int(OPTIMIZE):
  3120. return True
  3121. return False
  3122. def _is_backward_op(self):
  3123. op_maker = core.op_proto_and_checker_maker
  3124. BACKWARD = core.op_proto_and_checker_maker.OpRole.Backward
  3125. if not self.desc.has_attr(op_maker.kOpRoleAttrName()):
  3126. return False
  3127. op_role = self.desc.attr(op_maker.kOpRoleAttrName())
  3128. if op_role & int(BACKWARD):
  3129. return True
  3130. return False
  3131. @property
  3132. def dist_attr(self):
  3133. """
  3134. Get distributed attribute of this Variable.
  3135. """
  3136. return self.desc.dist_attr
  3137. @dist_attr.setter
  3138. def dist_attr(self, dist_attr):
  3139. """
  3140. Set distributed attribute of this Variable.
  3141. """
  3142. self.desc.dist_attr = dist_attr
  3143. def set_amp_options(self, amp_options):
  3144. """
  3145. Set auto cast attribute of this Operator.
  3146. Args:
  3147. amp_options (AmpOptions): AmpOptions of this Operator.
  3148. """
  3149. self._amp_options = amp_options
  3150. @property
  3151. def amp_options(self):
  3152. """
  3153. Get auto cast attribute of this Operator.
  3154. Returns:
  3155. bool: AmpOptions of this Operator.
  3156. """
  3157. return self._amp_options
  3158. @property
  3159. def struct_name(self):
  3160. return self._struct_name
  3161. @struct_name.setter
  3162. def struct_name(self, struct_name):
  3163. self._struct_name = struct_name
  3164. @signature_safe_contextmanager
  3165. def _stride_in_no_check_dy2st_diff():
  3166. global _stride_in_no_check_dy2st_diff_mode
  3167. _stride_in_no_check_dy2st_diff_mode = True
  3168. try:
  3169. yield
  3170. finally:
  3171. _stride_in_no_check_dy2st_diff_mode = False
  3172. def check_if_to_static_diff_with_dygraph(op_type, inplace_map, outputs):
  3173. if op_type in {"while", "conditional_block"}:
  3174. # Dont' need check while and conditional_block, it is only a wrapper of inner ops
  3175. # we will stuck in inner op.
  3176. return
  3177. if outputs is not None:
  3178. for k, v in outputs.items():
  3179. if isinstance(v, Variable):
  3180. if v.is_view_var and not (
  3181. op_type == "set_value"
  3182. and inplace_map.get("Input", None) == "Out"
  3183. ):
  3184. raise ValueError(
  3185. f"Sorry about what's happened. In to_static mode, {op_type}'s output variable {k} is a viewed Tensor in dygraph. This will result in inconsistent calculation behavior between dynamic and static graphs. If you are sure it is safe, you can call with paddle.base.framework._stride_in_no_check_dy2st_diff() in your safe code block."
  3186. )
  3187. elif isinstance(v, list):
  3188. for var in v:
  3189. if isinstance(var, Variable):
  3190. if var.is_view_var and not (
  3191. op_type == "set_value"
  3192. and inplace_map.get("Input", None) == "Out"
  3193. ):
  3194. raise ValueError(
  3195. f"Sorry about what's happend. In to_static mode, {op_type}'s output variable {k} is a viewed Tensor in dygraph. This will result in inconsistent calculation behavior between dynamic and static graphs. If you are sure it is safe, you can call with paddle.base.framework._stride_in_no_check_dy2st_diff() in your safe code block."
  3196. )
  3197. def record_is_view_var(op_type, inputs, outputs):
  3198. if op_type == "slice":
  3199. if inputs is not None and isinstance(inputs["Input"], list):
  3200. if hasattr(inputs["Input"][0], "is_view_var"):
  3201. inputs["Input"][0].is_view_var = True
  3202. else:
  3203. if hasattr(inputs["Input"], "is_view_var"):
  3204. inputs["Input"].is_view_var = True
  3205. if outputs is not None and isinstance(outputs["Out"], list):
  3206. if hasattr(outputs["Out"][0], "is_view_var"):
  3207. outputs["Out"][0].is_view_var = True
  3208. else:
  3209. if hasattr(outputs["Out"], "is_view_var"):
  3210. outputs["Out"].is_view_var = True
  3211. elif op_type == "strided_slice":
  3212. if inputs is not None and isinstance(inputs["Input"], list):
  3213. if hasattr(inputs["Input"][0], "is_view_var"):
  3214. inputs["Input"][0].is_view_var = True
  3215. else:
  3216. if hasattr(inputs["Input"], "is_view_var"):
  3217. inputs["Input"].is_view_var = True
  3218. if outputs is not None and isinstance(outputs["Out"], list):
  3219. if hasattr(outputs["Out"][0], "is_view_var"):
  3220. outputs["Out"][0].is_view_var = True
  3221. else:
  3222. if hasattr(outputs["Out"], "is_view_var"):
  3223. outputs["Out"].is_view_var = True
  3224. elif op_type == "index_select":
  3225. if inputs is not None and isinstance(inputs["X"], list):
  3226. if hasattr(inputs["X"][0], "is_view_var"):
  3227. inputs["X"][0].is_view_var = True
  3228. else:
  3229. if hasattr(inputs["X"], "is_view_var"):
  3230. inputs["X"].is_view_var = True
  3231. if outputs is not None and isinstance(outputs["Out"], list):
  3232. if hasattr(outputs["Out"][0], "is_view_var"):
  3233. outputs["Out"][0].is_view_var = True
  3234. else:
  3235. if hasattr(outputs["Out"], "is_view_var"):
  3236. outputs["Out"].is_view_var = True
  3237. elif op_type == "split":
  3238. if inputs is not None and isinstance(inputs["X"], list):
  3239. if hasattr(inputs["X"][0], "is_view_var"):
  3240. inputs["X"][0].is_view_var = True
  3241. else:
  3242. if hasattr(inputs["X"], "is_view_var"):
  3243. inputs["X"].is_view_var = True
  3244. if outputs is not None:
  3245. for out in outputs["Out"]:
  3246. if hasattr(out, "is_view_var"):
  3247. out.is_view_var = True
  3248. elif op_type == "unsqueeze" or op_type == "unsqueeze2":
  3249. if inputs is not None and isinstance(inputs["X"], list):
  3250. if hasattr(inputs["X"][0], "is_view_var"):
  3251. inputs["X"][0].is_view_var = True
  3252. else:
  3253. if hasattr(inputs["X"], "is_view_var"):
  3254. inputs["X"].is_view_var = True
  3255. if outputs is not None and isinstance(outputs["Out"], list):
  3256. if hasattr(outputs["Out"][0], "is_view_var"):
  3257. outputs["Out"][0].is_view_var = True
  3258. else:
  3259. if hasattr(outputs["Out"], "is_view_var"):
  3260. outputs["Out"].is_view_var = True
  3261. elif op_type == "squeeze" or op_type == "squeeze2":
  3262. if inputs is not None and isinstance(inputs["X"], list):
  3263. if hasattr(inputs["X"][0], "is_view_var"):
  3264. inputs["X"][0].is_view_var = True
  3265. else:
  3266. if hasattr(inputs["X"], "is_view_var"):
  3267. inputs["X"].is_view_var = True
  3268. if outputs is not None and isinstance(outputs["Out"], list):
  3269. if hasattr(outputs["Out"][0], "is_view_var"):
  3270. outputs["Out"][0].is_view_var = True
  3271. else:
  3272. if hasattr(outputs["Out"], "is_view_var"):
  3273. outputs["Out"].is_view_var = True
  3274. elif op_type == "transpose" or op_type == "transpose2":
  3275. if inputs is not None and isinstance(inputs["X"], list):
  3276. if hasattr(inputs["X"][0], "is_view_var"):
  3277. inputs["X"][0].is_view_var = True
  3278. else:
  3279. if hasattr(inputs["X"], "is_view_var"):
  3280. inputs["X"].is_view_var = True
  3281. if outputs is not None and isinstance(outputs["Out"], list):
  3282. if hasattr(outputs["Out"][0], "is_view_var"):
  3283. outputs["Out"][0].is_view_var = True
  3284. else:
  3285. if hasattr(outputs["Out"], "is_view_var"):
  3286. outputs["Out"].is_view_var = True
  3287. elif op_type == "unbind":
  3288. if inputs is not None and isinstance(inputs["X"], list):
  3289. if hasattr(inputs["X"][0], "is_view_var"):
  3290. inputs["X"][0].is_view_var = True
  3291. else:
  3292. if hasattr(inputs["X"], "is_view_var"):
  3293. inputs["X"].is_view_var = True
  3294. if outputs is not None and isinstance(outputs["Out"], list):
  3295. if hasattr(outputs["Out"][0], "is_view_var"):
  3296. outputs["Out"][0].is_view_var = True
  3297. else:
  3298. if hasattr(outputs["Out"], "is_view_var"):
  3299. outputs["Out"].is_view_var = True
  3300. elif op_type == "diagonal":
  3301. if inputs is not None and isinstance(inputs["Input"], list):
  3302. if hasattr(inputs["Input"][0], "is_view_var"):
  3303. inputs["Input"][0].is_view_var = True
  3304. else:
  3305. if hasattr(inputs["Input"], "is_view_var"):
  3306. inputs["Input"].is_view_var = True
  3307. if outputs is not None and isinstance(outputs["Out"], list):
  3308. if hasattr(outputs["Out"][0], "is_view_var"):
  3309. outputs["Out"][0].is_view_var = True
  3310. else:
  3311. if hasattr(outputs["Out"], "is_view_var"):
  3312. outputs["Out"].is_view_var = True
  3313. elif op_type == "flatten":
  3314. if inputs is not None and isinstance(inputs["X"], list):
  3315. if hasattr(inputs["X"][0], "is_view_var"):
  3316. inputs["X"][0].is_view_var = True
  3317. else:
  3318. if hasattr(inputs["X"], "is_view_var"):
  3319. inputs["X"].is_view_var = True
  3320. if outputs is not None and isinstance(outputs["Out"], list):
  3321. if hasattr(outputs["Out"][0], "is_view_var"):
  3322. outputs["Out"][0].is_view_var = True
  3323. else:
  3324. if hasattr(outputs["Out"], "is_view_var"):
  3325. outputs["Out"].is_view_var = True
  3326. elif op_type == "imag":
  3327. if inputs is not None and isinstance(inputs["X"], list):
  3328. if hasattr(inputs["X"][0], "is_view_var"):
  3329. inputs["X"][0].is_view_var = True
  3330. else:
  3331. if hasattr(inputs["X"], "is_view_var"):
  3332. inputs["X"].is_view_var = True
  3333. if outputs is not None and isinstance(outputs["Out"], list):
  3334. if hasattr(outputs["Out"][0], "is_view_var"):
  3335. outputs["Out"][0].is_view_var = True
  3336. else:
  3337. if hasattr(outputs["Out"], "is_view_var"):
  3338. outputs["Out"].is_view_var = True
  3339. elif op_type == "real":
  3340. if inputs is not None and isinstance(inputs["X"], list):
  3341. if hasattr(inputs["X"][0], "is_view_var"):
  3342. inputs["X"][0].is_view_var = True
  3343. else:
  3344. if hasattr(inputs["X"], "is_view_var"):
  3345. inputs["X"].is_view_var = True
  3346. if outputs is not None and isinstance(outputs["Out"], list):
  3347. if hasattr(outputs["Out"][0], "is_view_var"):
  3348. outputs["Out"][0].is_view_var = True
  3349. else:
  3350. if hasattr(outputs["Out"], "is_view_var"):
  3351. outputs["Out"].is_view_var = True
  3352. elif op_type == "reshape" or op_type == "reshape2":
  3353. if inputs is not None and isinstance(inputs["X"], list):
  3354. if hasattr(inputs["X"][0], "is_view_var"):
  3355. inputs["X"][0].is_view_var = True
  3356. else:
  3357. if hasattr(inputs["X"], "is_view_var"):
  3358. inputs["X"].is_view_var = True
  3359. if outputs is not None and isinstance(outputs["Out"], list):
  3360. if hasattr(outputs["Out"][0], "is_view_var"):
  3361. outputs["Out"][0].is_view_var = True
  3362. else:
  3363. if hasattr(outputs["Out"], "is_view_var"):
  3364. outputs["Out"].is_view_var = True
  3365. elif op_type == "as_real":
  3366. if inputs is not None and isinstance(inputs["X"], list):
  3367. if hasattr(inputs["X"][0], "is_view_var"):
  3368. inputs["X"][0].is_view_var = True
  3369. else:
  3370. if hasattr(inputs["X"], "is_view_var"):
  3371. inputs["X"].is_view_var = True
  3372. if outputs is not None and isinstance(outputs["Out"], list):
  3373. if hasattr(outputs["Out"][0], "is_view_var"):
  3374. outputs["Out"][0].is_view_var = True
  3375. else:
  3376. if hasattr(outputs["Out"], "is_view_var"):
  3377. outputs["Out"].is_view_var = True
  3378. class Block:
  3379. """
  3380. In Fluid, a Program is consistence of multi-Block, and Block stores
  3381. VarDesc and OpDesc. In a specific Block, a VarDesc have a unique name.
  3382. One block could have some child blocks, and child block's name scopes
  3383. should inherit the parent's so that OpDesc in child block can reference
  3384. a VarDesc that is stored in the parent block.
  3385. Please reference the framework.proto for details.
  3386. Args:
  3387. program(Program): The Program that the Block belongs to.
  3388. idx(int): The block's id in the Program.
  3389. Notes:
  3390. The constructor of Block should not be invoked directly. Please
  3391. use `Program._create_block()` to create a block.
  3392. Examples:
  3393. .. code-block:: python
  3394. >>> import paddle
  3395. >>> paddle.enable_static()
  3396. >>> cur_program = paddle.static.Program()
  3397. >>> cur_block = cur_program.current_block()
  3398. >>> var = cur_block.create_var(name="X",
  3399. ... shape=[-1, 23, 48],
  3400. ... dtype='float32')
  3401. >>> cur_block.append_op(type="abs",
  3402. ... inputs={"X": [var]},
  3403. ... outputs={"Out": [var]})
  3404. """
  3405. def __init__(self, program, idx):
  3406. self.desc = program.desc.block(idx)
  3407. self.vars = collections.OrderedDict() # var_name --> var
  3408. self.ops = [] # operator list
  3409. self.program = program
  3410. def __str__(self):
  3411. return self._to_readable_code()
  3412. def _to_readable_code(self, skip_op_callstack=True):
  3413. """
  3414. Get readable debug string of Block.
  3415. .. note::
  3416. If you want to get the debug string in protobuf format,
  3417. please use :code:`to_string` method.
  3418. Args:
  3419. skip_op_callstack(bool): whether to skip parsing Operator's attribute
  3420. op_callstack, default value is True
  3421. Returns:
  3422. string: The formatted Block string.
  3423. Examples:
  3424. .. code-block:: python
  3425. >>> import paddle
  3426. >>> paddle.enable_static()
  3427. >>> cur_program = paddle.static.Program()
  3428. >>> cur_block = cur_program.current_block()
  3429. >>> new_var = cur_block.create_var(name="X",
  3430. ... shape=[-1, 23, 48],
  3431. ... dtype='float32')
  3432. >>> new_op = cur_block.append_op(type="abs",
  3433. ... inputs={"X": [new_var]},
  3434. ... outputs={"Out": [new_var]})
  3435. >>> print(cur_block._to_readable_code())
  3436. """
  3437. assert isinstance(
  3438. skip_op_callstack, bool
  3439. ), f"skip_op_callstack parameter's type is error, expect bool, received {type(skip_op_callstack)}"
  3440. block_str = f"{{ // block_idx:{self.idx} parent_idx:{self.parent_idx} forward_idx:{self.forward_block_idx} backward_idx:{self.backward_block_idx}\n"
  3441. for var in list(self.vars.values()):
  3442. block_str += f" {var._to_readable_code()}\n"
  3443. block_str += "\n"
  3444. for op in self.ops:
  3445. block_str += f" {op._to_readable_code(skip_op_callstack)}\n"
  3446. block_str += "}"
  3447. return block_str
  3448. def to_string(self, throw_on_error, with_details=False):
  3449. """
  3450. Get debug string.
  3451. Args:
  3452. throw_on_error(bool): raise exception when self is not initialized
  3453. when throw_on_error is True.
  3454. with_details(bool): more details about variables and parameters
  3455. (e.g. trainable, optimize_attr, ...) will be printed when
  3456. with_details is True. Default False.
  3457. Returns:
  3458. str: The debug string.
  3459. """
  3460. assert isinstance(throw_on_error, bool) and isinstance(
  3461. with_details, bool
  3462. )
  3463. if with_details:
  3464. re_add_indent = re.compile(r"\n(.)")
  3465. res_str = "blocks {\n idx: %d\n parent_idx: %d" % (
  3466. self.idx,
  3467. self.parent_idx,
  3468. )
  3469. for var in list(self.vars.values()):
  3470. res_str += "\n vars {\n %s }" % re_add_indent.sub(
  3471. r"\n \1", var.to_string(throw_on_error, with_details)
  3472. )
  3473. for op in self.ops:
  3474. res_str += "\n ops {\n %s }" % re_add_indent.sub(
  3475. r"\n \1", op.to_string(throw_on_error)
  3476. )
  3477. res_str += "\n}"
  3478. else:
  3479. protostr = self.desc.serialize_to_string()
  3480. proto = framework_pb2.BlockDesc.FromString(bytes(protostr))
  3481. res_str = _debug_string_(proto, throw_on_error)
  3482. return res_str
  3483. __repr__ = __str__
  3484. @property
  3485. def parent_idx(self):
  3486. return self.desc.parent
  3487. @property
  3488. def forward_block_idx(self):
  3489. return self.desc.get_forward_block_idx()
  3490. def _set_forward_block_idx(self, idx):
  3491. """
  3492. Set the forward block Idx.
  3493. Args:
  3494. idx(int): the block index.
  3495. Returns:
  3496. None
  3497. """
  3498. self.desc._set_forward_block_idx(idx)
  3499. @property
  3500. def backward_block_idx(self):
  3501. cur_block_idx = self.idx
  3502. for block in self.program.blocks:
  3503. if block.forward_block_idx == cur_block_idx:
  3504. return block.idx
  3505. return -1
  3506. @property
  3507. def idx(self):
  3508. return self.desc.id
  3509. def var(self, name):
  3510. """
  3511. Get a Variable by name from this block.
  3512. Args:
  3513. name(str): the Variable's name.
  3514. Raises:
  3515. ValueError: The If input's type is not str, or this block
  3516. doesn't have a Variable with the giving name.
  3517. Returns:
  3518. Variable: the Variable with the giving name.
  3519. """
  3520. if not isinstance(name, str):
  3521. raise TypeError(
  3522. "var require string as parameter, but get %s instead."
  3523. % (type(name))
  3524. )
  3525. v = self.vars.get(name, None)
  3526. if v is None:
  3527. raise ValueError("var %s not in this block" % name)
  3528. return v
  3529. def _find_var_recursive(self, name):
  3530. """
  3531. Get a Variable by name from this block recursively.
  3532. Args:
  3533. name(str): the Variable's name.
  3534. Returns:
  3535. Variable: the Variable with the giving name. Or None if not found.
  3536. """
  3537. frontier = []
  3538. visited = set()
  3539. frontier.append(self)
  3540. prog = self.program
  3541. while len(frontier) != 0: # BFS
  3542. cur = frontier[0]
  3543. frontier = frontier[1:]
  3544. if id(cur) in visited:
  3545. continue
  3546. if cur.has_var(name):
  3547. return cur.var(name)
  3548. if cur.parent_idx != -1:
  3549. frontier.append(prog.block(cur.parent_idx))
  3550. if cur.forward_block_idx != -1:
  3551. frontier.append(prog.block(cur.forward_block_idx))
  3552. visited.add(id(cur))
  3553. return None
  3554. def _var_recursive(self, name):
  3555. """
  3556. Get a Variable by name from this block recursively.
  3557. Args:
  3558. name(str): the Variable's name.
  3559. Raises:
  3560. ValueError: this block and this parent block doesn't
  3561. have a Variable with the giving name.
  3562. Returns:
  3563. Variable: the Variable with the giving name.
  3564. """
  3565. var = self._find_var_recursive(name)
  3566. if var:
  3567. return var
  3568. else:
  3569. raise ValueError(f"Var {name} is not found recursively")
  3570. def all_parameters(self):
  3571. return list(self.iter_parameters())
  3572. def iter_parameters(self):
  3573. return (
  3574. item[1]
  3575. for item in self.vars.items()
  3576. if isinstance(item[1], Parameter)
  3577. )
  3578. def create_var(self, *args, **kwargs):
  3579. if in_dygraph_mode():
  3580. var = _create_tensor(*args, **kwargs)
  3581. else:
  3582. var = Variable(block=self, *args, **kwargs)
  3583. if "initializer" in kwargs:
  3584. kwargs["initializer"](var, self)
  3585. return var
  3586. def has_var(self, name):
  3587. return name in self.vars
  3588. def _rename_var(self, name, new_name):
  3589. """
  3590. Rename variable in vars and ops' inputs and outputs
  3591. Args:
  3592. name(str|bytes): the name that need to be renamed.
  3593. new_name(str|bytes): the name that need to rename to.
  3594. Raises:
  3595. ValueError: If this block doesn't have this the giving name,
  3596. or the type of the var with the giving name is not Parameter
  3597. or Variable.
  3598. Returns:
  3599. Variable: the Variable with the giving name.
  3600. """
  3601. # Ensure the type of name and new_name is str
  3602. name = name.decode() if isinstance(name, bytes) else name
  3603. new_name = (
  3604. new_name.decode() if isinstance(new_name, bytes) else new_name
  3605. )
  3606. if not self.has_var(name):
  3607. raise ValueError("var %s is not in current block" % name)
  3608. v = self.var(name)
  3609. if type(v) == Parameter:
  3610. var_type = "Parameter"
  3611. stop_gradient = v.stop_gradient
  3612. trainable = v.trainable
  3613. optimize_attr = v.optimize_attr
  3614. regularizer = v.regularizer
  3615. error_clip = v.error_clip
  3616. elif type(v) == Variable:
  3617. var_type = "Variable"
  3618. error_clip = v.error_clip
  3619. stop_gradient = v.stop_gradient
  3620. else:
  3621. raise ValueError("unsupported var type: %s", type(v))
  3622. orig_var_type = v.type
  3623. self.desc._rename_var(name.encode(), new_name.encode())
  3624. # NOTE: v is destroyed by C++ after calling _rename_var.
  3625. d = self.desc.find_var(new_name.encode())
  3626. if var_type == "Parameter":
  3627. if in_dygraph_mode():
  3628. var = EagerParamBase(
  3629. d.shape(),
  3630. d.dtype(),
  3631. type=orig_var_type,
  3632. name=new_name,
  3633. stop_gradient=stop_gradient,
  3634. trainable=trainable,
  3635. optimize_attr=optimize_attr,
  3636. regularizer=regularizer,
  3637. error_clip=error_clip,
  3638. )
  3639. else:
  3640. var = Parameter(
  3641. self,
  3642. d.shape(),
  3643. d.dtype(),
  3644. type=orig_var_type,
  3645. name=new_name,
  3646. stop_gradient=stop_gradient,
  3647. trainable=trainable,
  3648. optimize_attr=optimize_attr,
  3649. regularizer=regularizer,
  3650. error_clip=error_clip,
  3651. )
  3652. elif var_type == "Variable":
  3653. var = Variable(
  3654. self,
  3655. type=orig_var_type,
  3656. name=new_name,
  3657. error_clip=error_clip,
  3658. stop_gradient=stop_gradient,
  3659. )
  3660. # rename the python side, _sync_with_cpp will only add
  3661. # new vars/ops to python side.
  3662. self.vars[new_name] = var
  3663. del self.vars[name]
  3664. self._sync_with_cpp()
  3665. return var
  3666. def _remove_var(self, name, sync=True):
  3667. if sync is True:
  3668. self._sync_with_cpp()
  3669. self.desc._remove_var(name.encode())
  3670. del self.vars[name]
  3671. def create_parameter(self, *args, **kwargs):
  3672. global_block = self.program.global_block()
  3673. param = None
  3674. if in_dygraph_mode():
  3675. param = EagerParamBase(*args, **kwargs)
  3676. else:
  3677. param = Parameter(global_block, *args, **kwargs)
  3678. # NOTE(Aurelius84): we deliver stop_gradient in append_op, so we
  3679. # need record it state and reset it back after calling this API
  3680. stop_gradient = param.stop_gradient
  3681. if "initializer" in kwargs:
  3682. def _is_inited_by(block, var):
  3683. init_ops = []
  3684. for op in block.ops:
  3685. if var.name in op.output_arg_names:
  3686. # In startup_program, "c_broadcast" and "c_sync_comm_stream"
  3687. # are treated as initialization ops that cause error.
  3688. # Think of "c_broadcast" and "c_sync_comm_stream" as a special case here.
  3689. # NOTE: "coalesce_tensor" is a special case for rnn with cudnn support
  3690. if op.type in [
  3691. "c_broadcast",
  3692. "c_sync_comm_stream",
  3693. "coalesce_tensor",
  3694. ]:
  3695. continue
  3696. init_ops.append(op)
  3697. return init_ops
  3698. initializer = kwargs["initializer"]
  3699. init_ops = _is_inited_by(global_block, param)
  3700. init_ops_len = len(init_ops)
  3701. if init_ops_len > 1:
  3702. raise RuntimeError(
  3703. "param "
  3704. + param.name
  3705. + " is inited by multiple init ops "
  3706. + str(init_ops)
  3707. )
  3708. elif init_ops_len == 1:
  3709. # TODO already inited, do nothing, should log a warning
  3710. pass
  3711. else:
  3712. initializer(param, self)
  3713. param.stop_gradient = stop_gradient
  3714. return param
  3715. def append_op(self, *args, **kwargs):
  3716. """
  3717. Appends a new Operator according to the giving arguments.
  3718. Returns:
  3719. Operator: the append Operator.
  3720. """
  3721. inplace_map = kwargs.get("inplace_map", None)
  3722. op_type = kwargs.get("type", None)
  3723. if in_dygraph_mode():
  3724. attrs = kwargs.get("attrs", {})
  3725. warnings.warn(
  3726. "Op `%s` is executed through `append_op` under the dynamic mode, "
  3727. "the corresponding API implementation needs to be upgraded to "
  3728. "using `_C_ops` method." % type,
  3729. DeprecationWarning,
  3730. )
  3731. op = Operator(
  3732. block=self,
  3733. desc=None,
  3734. type=op_type,
  3735. inputs=None,
  3736. outputs=None,
  3737. attrs=attrs,
  3738. )
  3739. # record ops in tracer rather than blocks
  3740. #
  3741. # TODO(minqiyang): add op stop_gradient support in static graph mode too.
  3742. # currently, we only support stop_gradient in dygraph mode.
  3743. _dygraph_tracer().trace_op(
  3744. op_type,
  3745. kwargs.get("inputs", {}),
  3746. kwargs.get("outputs", {}),
  3747. attrs if attrs else {},
  3748. kwargs.get("stop_gradient", False),
  3749. inplace_map,
  3750. )
  3751. else:
  3752. from paddle.base.dygraph.base import param_guard
  3753. from paddle.utils import flatten
  3754. def pass_stop_gradient(ins, outs):
  3755. """
  3756. Set out.stop_gradient = True if all inputs stop_gradient is True.
  3757. """
  3758. need_reset = True
  3759. for var in flatten(ins):
  3760. if getattr(var, "stop_gradient", None) is False:
  3761. need_reset = False
  3762. break
  3763. if need_reset:
  3764. for var in flatten(outs):
  3765. if isinstance(var, Variable):
  3766. var.stop_gradient = True
  3767. op_desc = self.desc.append_op()
  3768. inputs = kwargs.get("inputs", None)
  3769. outputs = kwargs.get("outputs", None)
  3770. # NOTE(Aurelius84): In case of @to_static, all Tensor(s) should
  3771. # be converted into Variable(s) with same name and block location.
  3772. # This is ONE and ONLY logic of type transformation of dy2static.
  3773. ignore_ops = {
  3774. "conditional_block",
  3775. "conditional_block_grad",
  3776. "pylayer",
  3777. "pylayer_grad",
  3778. "recurrent",
  3779. "recurrent_grad",
  3780. "while",
  3781. "while_grad",
  3782. }
  3783. from .dygraph.base import in_to_static_mode
  3784. if in_to_static_mode() and not _stride_in_no_check_dy2st_diff_mode:
  3785. check_if_to_static_diff_with_dygraph(
  3786. op_type, inplace_map, outputs
  3787. )
  3788. if op_type not in ignore_ops:
  3789. pass_stop_gradient(inputs, outputs)
  3790. with param_guard(inputs), param_guard(outputs):
  3791. op = Operator(
  3792. block=self,
  3793. desc=op_desc,
  3794. type=op_type,
  3795. inputs=inputs,
  3796. outputs=outputs,
  3797. attrs=kwargs.get("attrs", None),
  3798. )
  3799. self.ops.append(op)
  3800. if in_to_static_mode():
  3801. record_is_view_var(op_type, inputs, outputs)
  3802. return op
  3803. def _insert_op(self, index, *args, **kwargs):
  3804. """
  3805. Insert a Operator according to the giving arguments.
  3806. Args:
  3807. index(int): the place that the operator to insert.
  3808. Returns:
  3809. Operator: the insert Operator.
  3810. """
  3811. self._sync_with_cpp()
  3812. return self._insert_op_without_sync(index, *args, **kwargs)
  3813. def _insert_op_without_sync(self, index, *args, **kwargs):
  3814. """
  3815. Insert an Operator according to the giving arguments,
  3816. without sync_with_cpp to meke the compilation faster.
  3817. Args:
  3818. index(int): the place that the operator to insert.
  3819. Returns:
  3820. Operator: the insert Operator.
  3821. """
  3822. op_desc = self.desc._insert_op(index)
  3823. op = Operator(block=self, desc=op_desc, *args, **kwargs)
  3824. self.ops.insert(index, op)
  3825. return op
  3826. def _remove_op(self, index, sync=True):
  3827. """
  3828. Remove the specific position operator.
  3829. Args:
  3830. index(int): the position that the operator to insert.
  3831. Returns:
  3832. None
  3833. """
  3834. if sync is True:
  3835. self._sync_with_cpp()
  3836. self.desc._remove_op(index, index + 1)
  3837. del self.ops[index]
  3838. def _slice_ops(self, start, end):
  3839. """
  3840. Return the Operator between start and end.
  3841. Args:
  3842. start(int): the start position.
  3843. end(int): the end position.
  3844. Returns:
  3845. list: the Operators between start and end.
  3846. """
  3847. return self.ops[start:end]
  3848. def _prepend_op(self, *args, **kwargs):
  3849. if in_dygraph_mode():
  3850. type = kwargs.get("type", None)
  3851. attrs = kwargs.get("attrs", {})
  3852. op = Operator(
  3853. self, None, type=type, inputs=None, outputs=None, attrs=attrs
  3854. )
  3855. _dygraph_tracer().trace_op(
  3856. type,
  3857. kwargs.get("inputs", {}),
  3858. kwargs.get("outputs", {}),
  3859. attrs if attrs else {},
  3860. kwargs.get("stop_gradient", False),
  3861. )
  3862. else:
  3863. op_desc = self.desc._prepend_op()
  3864. op = Operator(
  3865. self,
  3866. op_desc,
  3867. type=kwargs.get("type", None),
  3868. inputs=kwargs.get("inputs", None),
  3869. outputs=kwargs.get("outputs", None),
  3870. attrs=kwargs.get("attrs", None),
  3871. )
  3872. self.ops.insert(0, op)
  3873. return op
  3874. def _sync_with_cpp(self):
  3875. """
  3876. Sync from the desc on the c++ end. This method is used to synchronize
  3877. the c++ desc instance generated by backward.
  3878. """
  3879. # sync variables from cpp
  3880. for var in self.desc.all_vars():
  3881. if not self.has_var(var.name()):
  3882. is_stop_gradient = False
  3883. if var.has_stop_gradient():
  3884. is_stop_gradient = var.stop_gradient()
  3885. if var.has_is_parameter() and var.is_parameter():
  3886. self.create_parameter(
  3887. name=var.name(),
  3888. desc=var,
  3889. type=var.type(),
  3890. shape=var.shape(),
  3891. dtype=var.dtype(),
  3892. stop_gradient=is_stop_gradient,
  3893. )
  3894. else:
  3895. self.create_var(
  3896. name=var.name(),
  3897. desc=var,
  3898. type=var.type(),
  3899. stop_gradient=is_stop_gradient,
  3900. )
  3901. # sync variables removed from c++ end
  3902. for var in list(self.vars.keys()):
  3903. if not self.desc.find_var(var.encode()):
  3904. self.vars.pop(var)
  3905. # sync operators from cpp
  3906. ops_in_cpp = []
  3907. for op_idx in range(0, self.desc.op_size()):
  3908. ops_in_cpp.append(self.desc.op(op_idx))
  3909. if len(self.ops) != 0:
  3910. first_op_in_python = self.ops[0].desc
  3911. last_op_in_python = self.ops[len(self.ops) - 1].desc
  3912. start_index = None
  3913. end_index = None
  3914. for index in range(len(ops_in_cpp)):
  3915. if first_op_in_python == ops_in_cpp[index]:
  3916. start_index = index
  3917. if last_op_in_python == ops_in_cpp[index]:
  3918. end_index = index
  3919. assert start_index is not None
  3920. assert end_index is not None
  3921. assert start_index <= end_index
  3922. else:
  3923. start_index = 0
  3924. end_index = -1
  3925. # sync ops append to the head of cpp_ops
  3926. for index in range((start_index - 1 - 1), -1, -1):
  3927. op_desc = ops_in_cpp[index]
  3928. op = Operator(self, op_desc)
  3929. self.ops.insert(0, op)
  3930. # sync ops append to the end of cpp_ops
  3931. for index in range((end_index + 1), len(ops_in_cpp)):
  3932. op_desc = ops_in_cpp[index]
  3933. op = Operator(self, op_desc)
  3934. self.ops.append(op)
  3935. # sync ops removed from c++ end
  3936. if end_index != -1 and end_index < len(self.ops):
  3937. ops_in_cpp_index = 0
  3938. ops_in_python_index = 0
  3939. while ops_in_python_index < len(
  3940. self.ops
  3941. ) and ops_in_cpp_index < len(ops_in_cpp):
  3942. if (
  3943. self.ops[ops_in_python_index].desc
  3944. != ops_in_cpp[ops_in_cpp_index]
  3945. ):
  3946. del self.ops[ops_in_python_index]
  3947. else:
  3948. ops_in_cpp_index += 1
  3949. ops_in_python_index += 1
  3950. assert len(self.ops) == len(ops_in_cpp)
  3951. for index in range(len(self.ops)):
  3952. assert self.ops[index].desc == ops_in_cpp[index]
  3953. def _copy_param_info_from(self, other):
  3954. """
  3955. Copy the information of parameters from the other block.
  3956. Args:
  3957. other(Block): the other block.
  3958. Raises:
  3959. ValueError: If type of input is not Block, or the `other` and this
  3960. block is not in the same topology.
  3961. Returns:
  3962. None
  3963. """
  3964. if not isinstance(other, Block):
  3965. raise TypeError(
  3966. "_copy_param_info_from should be invoked with Block"
  3967. )
  3968. for p in other.iter_parameters():
  3969. assert isinstance(p, Parameter)
  3970. v = self.vars.get(p.name, None)
  3971. if v is None:
  3972. # if the Parameter is pruned, v may be None
  3973. continue
  3974. assert isinstance(v, Variable)
  3975. new_p = None
  3976. if in_dygraph_mode():
  3977. new_p = EagerParamBase(
  3978. shape=v.shape,
  3979. dtype=v.dtype,
  3980. type=v.type,
  3981. lod_level=v.lod_level,
  3982. stop_gradient=p.stop_gradient,
  3983. trainable=p.trainable,
  3984. optimize_attr=p.optimize_attr,
  3985. regularizer=p.regularizer,
  3986. error_clip=p.error_clip,
  3987. name=v.name,
  3988. )
  3989. else:
  3990. new_p = Parameter(
  3991. block=self,
  3992. shape=v.shape,
  3993. dtype=v.dtype,
  3994. type=v.type,
  3995. lod_level=v.lod_level
  3996. if v.type == core.VarDesc.VarType.LOD_TENSOR
  3997. else None,
  3998. stop_gradient=p.stop_gradient,
  3999. trainable=p.trainable,
  4000. optimize_attr=p.optimize_attr,
  4001. regularizer=p.regularizer,
  4002. error_clip=p.error_clip,
  4003. name=v.name,
  4004. )
  4005. self.vars[new_p.name] = new_p
  4006. def _clone_variable(self, var, force_persistable=True):
  4007. """
  4008. Clone a variable into current block.
  4009. Args:
  4010. var: the variable to be cloned.
  4011. force_persistable(bool): True means setting the result variable to being persistable.
  4012. False means setting the persistable the same with that of input var.
  4013. default: True.
  4014. Returns:
  4015. Variable: the new variable cloned from 'var' in current block.
  4016. """
  4017. assert isinstance(var, Variable)
  4018. ret_var = None
  4019. # make STEP_SCOPES var can be safely cloned.
  4020. if var.type == core.VarDesc.VarType.STEP_SCOPES:
  4021. ret_var = self.create_var(
  4022. name=var.name, persistable=var.persistable, type=var.type
  4023. )
  4024. elif var.type == core.VarDesc.VarType.RAW:
  4025. ret_var = self.create_var(
  4026. name=var.name, persistable=var.persistable, type=var.type
  4027. )
  4028. elif var.type == core.VarDesc.VarType.SELECTED_ROWS:
  4029. ret_var = self.create_var(
  4030. name=var.name,
  4031. shape=var.shape,
  4032. dtype=var.dtype,
  4033. type=var.type,
  4034. persistable=True if force_persistable else var.persistable,
  4035. is_data=var.is_data,
  4036. need_check_feed=var.desc.need_check_feed(),
  4037. )
  4038. else:
  4039. ret_var = self.create_var(
  4040. name=var.name,
  4041. shape=var.shape,
  4042. dtype=var.dtype,
  4043. type=var.type,
  4044. lod_level=var.lod_level,
  4045. persistable=True if force_persistable else var.persistable,
  4046. is_data=var.is_data,
  4047. need_check_feed=var.desc.need_check_feed(),
  4048. )
  4049. return ret_var
  4050. # NOTE(zjl): you should be careful that after you call this method,
  4051. # some Python Variable and all Python Operators should not be used
  4052. # again. Because all Python Variables and all Python Operators are
  4053. # re-constructed inside this method. The underlying VarDesc(OpDesc)
  4054. # of some old Python Variables(all old Python Operators) may have
  4055. # been destructed.
  4056. def _apply_pass(
  4057. main_program, startup_program, pass_name, pass_attrs={}, pass_attr_types={}
  4058. ):
  4059. assert isinstance(pass_attrs, dict), "pass_attrs must be dict"
  4060. assert isinstance(pass_attr_types, dict), "pass_attr_types must be dict"
  4061. tmp_main_program = core.ProgramDesc(main_program.desc)
  4062. tmp_startup_program = core.ProgramDesc(startup_program.desc)
  4063. attrs = core.apply_pass(
  4064. tmp_main_program,
  4065. tmp_startup_program,
  4066. pass_name,
  4067. pass_attrs,
  4068. pass_attr_types,
  4069. )
  4070. main_program._rebuild_from_desc(tmp_main_program)
  4071. startup_program._rebuild_from_desc(tmp_startup_program)
  4072. return attrs
  4073. class IrNode:
  4074. """
  4075. Python IrNode. Beneath it is a core.Node, which is used for Ir Pass.
  4076. """
  4077. def __init__(self, node):
  4078. """
  4079. Construct an IrNode using core.Node.
  4080. Args:
  4081. node(core.Node): C++ Node.
  4082. """
  4083. assert isinstance(
  4084. node, core.Node
  4085. ), "node must be the instance of core.Node."
  4086. self.node = node
  4087. def name(self):
  4088. """
  4089. Return the node name.
  4090. Returns:
  4091. str: node name.
  4092. """
  4093. return self.node.name()
  4094. def node_type(self):
  4095. """
  4096. Return the node type.
  4097. Returns:
  4098. core.Node.Type: node type(core.Node.Type.Operation or core.Node.Type.Variable).
  4099. """
  4100. return self.node.node_type()
  4101. def var(self):
  4102. """
  4103. Return the node variable description.
  4104. Returns:
  4105. core.VarDesc: node variable description.
  4106. """
  4107. return self.node.var()
  4108. def op(self):
  4109. """
  4110. Return the node operator description.
  4111. Returns:
  4112. core.OpDesc: node operator description.
  4113. """
  4114. return self.node.op()
  4115. def id(self):
  4116. """
  4117. Return the node id.
  4118. Returns:
  4119. int: node id.
  4120. """
  4121. return self.node.id()
  4122. def is_op(self):
  4123. """
  4124. If the node is an operator, then return true.
  4125. Returns:
  4126. bool: indicate whether the node is an operator.
  4127. """
  4128. return self.node.is_op()
  4129. def is_var(self):
  4130. """
  4131. If the node is a variable, then return true.
  4132. Returns:
  4133. bool: indicate whether the node is a variable.
  4134. """
  4135. return self.node.is_var()
  4136. def is_ctrl_var(self):
  4137. """
  4138. If the node is a control dependence variable, then return true.
  4139. Returns:
  4140. bool: indicate whether the node is a control dependence variable.
  4141. """
  4142. return self.node.is_ctrl_var()
  4143. def clear_inputs(self):
  4144. """
  4145. Clear the node inputs. After executing the `clear_inputs` function,
  4146. the node inputs will be empty.
  4147. """
  4148. self.node.clear_inputs()
  4149. def remove_input_by_id(self, node_id):
  4150. """
  4151. Remove a node from inputs by the given node id.
  4152. Args:
  4153. node_id(int): the given node id.
  4154. """
  4155. self.node.remove_input(node_id)
  4156. def remove_input(self, node):
  4157. """
  4158. Remove a node from inputs.
  4159. Args:
  4160. node(IrNode): the node being removed.
  4161. """
  4162. self.node.remove_input(node.node)
  4163. def append_input(self, node):
  4164. """
  4165. Append a node in inputs.
  4166. Args:
  4167. node(IrNode): the node being appended.
  4168. """
  4169. self.node.append_input(node.node)
  def clear_outputs(self):
      """
      Drop every output of the wrapped node by delegating to the underlying
      node's `clear_outputs`. Nothing is returned.
      """
      wrapped_node = self.node
      wrapped_node.clear_outputs()
  def remove_output_by_id(self, node_id):
      """
      Detach an output from the wrapped node, identified by its id.

      Args:
          node_id(int): id of the output node to detach; it is forwarded
              unchanged to the underlying node's `remove_output`.
      """
      wrapped_node = self.node
      wrapped_node.remove_output(node_id)
  def remove_output(self, node):
      """
      Detach an output from the wrapped node.

      Args:
          node(IrNode): the wrapper whose underlying `node` attribute is
              passed to the wrapped node's `remove_output`.
      """
      wrapped_node = self.node
      wrapped_node.remove_output(node.node)
  def append_output(self, node):
      """
      Attach a new output to the wrapped node.

      Args:
          node(IrNode): the wrapper whose underlying `node` attribute is
              passed to the wrapped node's `append_output`.
      """
      wrapped_node = self.node
      wrapped_node.append_output(node.node)
  @property
  def inputs(self):
      """
      Collect the inputs of the wrapped node.

      Returns:
          list(IrNode): one IrNode wrapper per entry of the underlying
          node's `inputs`, in iteration order.
      """
      wrapped_inputs = []
      for raw_node in self.node.inputs:
          wrapped_inputs.append(IrNode(raw_node))
      return wrapped_inputs
  @property
  def outputs(self):
      """
      Collect the outputs of the wrapped node.

      Returns:
          list(IrNode): one IrNode wrapper per entry of the underlying
          node's `outputs`, in iteration order.
      """
      wrapped_outputs = []
      for raw_node in self.node.outputs:
          wrapped_outputs.append(IrNode(raw_node))
      return wrapped_outputs
  class IrVarNode(IrNode):
      """
      Python wrapper for a core.Node that represents a variable. It extends
      IrNode with accessors that read or modify the node's variable description.
      """

      def __init__(self, node):
          """
          Wrap a C++ variable node.

          Args:
              node(core.Node): C++ Node; it must report itself as a variable.
          """
          is_variable_node = isinstance(node, core.Node) and node.is_var()
          assert (
              is_variable_node
          ), "node must be the instance of core.Node and it must be a variable node."
          super().__init__(node)
          self.node = node

      def _checked_var_desc(self):
          """
          Assert that the node carries a variable description, then return it.
          """
          assert (
              self.node.var() is not None
          ), "The node variable description can not be None."
          return self.node.var()

      def set_shape(self, shape):
          """
          Overwrite the shape stored in the node's variable description.

          Args:
              shape(list): the new shape.
          """
          self._checked_var_desc().set_shape(shape)

      def persistable(self):
          """
          Tell whether the variable behind this node is persistable.

          Returns:
              bool: the persistable flag of the variable description.
          """
          var_desc = self._checked_var_desc()
          return var_desc.persistable()

      def type(self):
          """
          Fetch the type of the variable behind this node.

          Returns:
              core.VarDesc.VarType: the variable type.
          """
          var_desc = self._checked_var_desc()
          return var_desc.type()

      def dtype(self):
          """
          Fetch the data type of the variable behind this node.

          Returns:
              core.VarDesc.VarType: the variable data type.
          """
          var_desc = self._checked_var_desc()
          return var_desc.dtype()

      def shape(self):
          """
          Fetch the shape of the variable behind this node.

          Returns:
              list: the variable shape.
          """
          var_desc = self._checked_var_desc()
          return var_desc.shape()

      @property
      def inputs(self):
          """
          Collect the inputs of this variable node.

          Returns:
              list(IrOpNode): every input of the underlying node, each wrapped
              as an IrOpNode.
          """
          producers = []
          for raw_node in self.node.inputs:
              producers.append(IrOpNode(raw_node))
          return producers

      @property
      def outputs(self):
          """
          Collect the outputs of this variable node.

          Returns:
              list(IrOpNode): every output of the underlying node, each wrapped
              as an IrOpNode.
          """
          consumers = []
          for raw_node in self.node.outputs:
              consumers.append(IrOpNode(raw_node))
          return consumers
  class IrOpNode(IrNode):
      """
      Python wrapper for a core.Node that represents an operator. It extends
      IrNode with accessors that read or modify the node's operator description.
      """

      def __init__(self, node):
          """
          Wrap a C++ operator node.

          Args:
              node(core.Node): C++ Node; it must report itself as an operator.
          """
          is_operator_node = isinstance(node, core.Node) and node.is_op()
          assert (
              is_operator_node
          ), "node must be the instance of core.Node and it must be a operator node."
          super().__init__(node)
          self.node = node

      def _checked_op_desc(self):
          """
          Assert that the node carries an operator description, then return it.
          """
          assert (
              self.node.op() is not None
          ), "The node operator description can not be None."
          return self.node.op()

      def rename_input(self, old_input_name, new_input_name):
          """
          Replace an input name inside the operator description.

          Args:
              old_input_name(str): the input name to replace.
              new_input_name(str): the name that takes its place.
          """
          op_desc = self._checked_op_desc()
          op_desc._rename_input(old_input_name, new_input_name)

      def rename_output(self, old_output_name, new_output_name):
          """
          Replace an output name inside the operator description.

          Args:
              old_output_name(str): the output name to replace.
              new_output_name(str): the name that takes its place.
          """
          op_desc = self._checked_op_desc()
          op_desc._rename_output(old_output_name, new_output_name)

      def input(self, name):
          """
          Look up the input argument names bound to a parameter name.

          Args:
              name(str): the parameter name.

          Returns:
              list(str): the argument name list.
          """
          op_desc = self._checked_op_desc()
          return op_desc.input(name)

      def output(self, name):
          """
          Look up the output argument names bound to a parameter name.

          Args:
              name(str): the parameter name.

          Returns:
              list(str): the argument name list.
          """
          op_desc = self._checked_op_desc()
          return op_desc.output(name)

      def set_type(self, new_type):
          """
          Switch the operator to a new type. The result of the underlying
          `set_type` call is passed back to the caller.

          Args:
              new_type(str): new operator type to be set.
          """
          op_desc = self._checked_op_desc()
          return op_desc.set_type(new_type)

      def set_attr(self, name, val):
          """
          Assign an attribute of the operator, addressed by its name.

          Args:
              name(str): the attribute name.
              val(bool|int|str|float|list): the value of the attribute.
          """
          self._update_desc_attr(name, val)

      def _update_desc_attr(self, name, val):
          """
          Store `val` under `name` in the op desc, picking the desc setter
          that matches the kind of value.
          """
          desc = self._checked_op_desc()
          # Each guard below handles one kind of value and returns immediately;
          # anything unmatched falls through to the generic setter.
          if isinstance(val, Variable):
              desc.set_var_attr(name, val.desc)
              return
          if isinstance(val, list) and _all_is_type(val, Variable):
              desc.set_vars_attr(name, [item.desc for item in val])
              return
          if isinstance(val, Block):
              desc.set_block_attr(name, val.desc)
              return
          if isinstance(val, list) and val and _all_is_type(val, Block):
              desc.set_blocks_attr(name, [item.desc for item in val])
              return
          if isinstance(val, (core.BlockDesc, core.ProgramDesc)):
              desc.set_serialized_attr(name, val.serialize_to_string())
              return
          desc._set_attr(name, val)

      def input_arg_names(self):
          """
          List the names of all input arguments of this op node.

          Returns:
              list(str): input arguments' names of this op node.
          """
          op_desc = self._checked_op_desc()
          return op_desc.input_arg_names()

      def output_arg_names(self):
          """
          List the names of all output arguments of this op node.

          Returns:
              list(str): output arguments' names of this op node.
          """
          op_desc = self._checked_op_desc()
          return op_desc.output_arg_names()

      @property
      def inputs(self):
          """
          Collect the inputs of this operator node.

          Returns:
              list(IrVarNode): every input of the underlying node, each wrapped
              as an IrVarNode.
          """
          consumed = []
          for raw_node in self.node.inputs:
              consumed.append(IrVarNode(raw_node))
          return consumed

      @property
      def outputs(self):
          """
          Collect the outputs of this operator node.

          Returns:
              list(IrVarNode): every output of the underlying node, each wrapped
              as an IrVarNode.
          """
          produced = []
          for raw_node in self.node.outputs:
              produced.append(IrVarNode(raw_node))
          return produced
  class IrGraph:
      """
      Python IrGraph. Beneath it is a core.Graph, which is used for
      creating a c++ Ir Pass Graph. An IrGraph is just a graph view of
      a Program. In an IrGraph, both Variables and Operators are graph
      nodes.
      """

      def __init__(self, graph, for_test=False):
          """
          Construct an IrGraph using core.Graph.

          Args:
              graph(core.Graph): C++ Graph.
              for_test(bool): True for the test graph and false for the train graph.
          """
          assert isinstance(
              graph, core.Graph
          ), "graph must be the instance of core.Graph."
          self.graph = graph
          self._for_test = for_test

      def clone(self):
          """
          Create a new and duplicated IrGraph.

          Warns:
              The method only clones the graph structure, not its attributes.

          Returns:
              IrGraph: A new and duplicated graph.
          """
          g = self.graph.clone()
          return IrGraph(g, self._for_test)

      def is_test(self):
          """
          If the graph is used for testing, the function returns true. Otherwise, returns false.
          """
          return self._for_test

      def all_nodes(self):
          """
          Return all nodes included in the graph as a set.
          """
          return {IrNode(node) for node in self.graph.nodes()}

      def all_var_nodes(self):
          """
          Return all variable nodes included in the graph as a set.
          """
          return {IrVarNode(node) for node in self.graph.nodes() if node.is_var()}

      def all_persistable_nodes(self):
          """
          Return all persistable variable nodes included in the graph as a set.
          """
          persistable_nodes = set()
          for node in self.graph.nodes():
              if (
                  node.is_var()
                  and node.var() is not None
                  and node.var().persistable()
              ):
                  persistable_nodes.add(node)
          return {IrVarNode(p) for p in persistable_nodes}

      def all_op_nodes(self):
          """
          Return all operator nodes included in the graph as a set.
          """
          return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}

      def all_sub_graphs(self, for_test=False):
          """
          Return all sub_graphs included in the main graph as a list, ordered
          by sub-graph index.

          Args:
              for_test(bool): the `for_test` flag given to every returned IrGraph.

          Returns:
              list(IrGraph): one IrGraph per sub-graph of the main graph.
          """
          return [
              IrGraph(self.graph.get_sub_graph(i), for_test=for_test)
              for i in range(self.graph.sub_graph_size())
          ]

      def get_sub_graph(self, i, for_test=False):
          """
          Return i-th sub_graph in the main graph.

          Args:
              i(int): index of the requested sub-graph.
              for_test(bool): the `for_test` flag given to the returned IrGraph.

          Returns:
              IrGraph: the wrapped i-th sub-graph.
          """
          return IrGraph(self.graph.get_sub_graph(i), for_test=for_test)

      def create_persistable_node(self, name, var_type, shape, var_dtype):
          """
          Create a persistable variable node in the graph. In IrGraph,
          it can not distinguish between persistable variables and parameters.

          Args:
              name(str): the name of the persistable variable node.
              var_type(core.VarDesc.VarType): the type of the persistable variable node.
              shape(list): the shape of the persistable variable node.
              var_dtype(core.VarDesc.VarType): the data type of the persistable variable node.

          Returns:
              IrVarNode: the created persistable variable node.
          """
          var_desc = core.VarDesc(name)
          var_desc.set_type(var_type)
          var_desc.set_shape(shape)
          var_desc.set_dtype(var_dtype)
          var_desc.set_persistable(True)
          return IrVarNode(self.graph.create_var_node(var_desc))

      def create_var_node(self, name, var_type, shape, var_dtype):
          """
          Create a variable node in the graph. The created variable node is
          not persistable.

          Args:
              name(str): the name of the variable node.
              var_type(core.VarDesc.VarType): the type of the variable node.
              shape(list): the shape of the variable node.
              var_dtype(core.VarDesc.VarType): the data type of the variable node.

          Returns:
              IrVarNode: the created variable node.
          """
          var_desc = core.VarDesc(name)
          var_desc.set_type(var_type)
          var_desc.set_shape(shape)
          var_desc.set_dtype(var_dtype)
          return IrVarNode(self.graph.create_var_node(var_desc))

      def create_control_dep_var(self):
          """
          Create a control dependence variable node in the graph.

          Returns:
              IrVarNode: the created control variable node.
          """
          return IrVarNode(self.graph.create_control_dep_var())

      def create_var_node_from_desc(self, var_desc):
          """
          Create a variable node by using an existing VarDesc in the graph.
          Depending on the given VarDesc, the created variable node may be persistable.

          Args:
              var_desc(core.VarDesc): the given variable description.

          Returns:
              IrVarNode: the created variable node.
          """
          return IrVarNode(self.graph.create_var_node(var_desc))

      def create_op_node(self, op_type, attrs, inputs, outputs):
          """
          Create an operator node in the graph.

          Args:
              op_type(str): the type of the operator node.
              attrs(dict): the attributes of the operator node.
              inputs(dict): the inputs of the operator node. Each value is a
                  node or a list of nodes; only their names are recorded.
              outputs(dict): the outputs of the operator node, in the same
                  form as `inputs`.

          Returns:
              IrOpNode: the created operator node.
          """
          op_desc = core.OpDesc()
          op_desc.set_type(op_type)
          for attr, value in attrs.items():
              self._update_desc_attr(op_desc, attr, value)
          for input_name, var_nodes in inputs.items():
              # A single node is accepted in place of a one-element list.
              if not isinstance(var_nodes, list):
                  var_nodes = [var_nodes]
              op_desc.set_input(
                  input_name, [var_node.name() for var_node in var_nodes]
              )
          for output_name, var_nodes in outputs.items():
              if not isinstance(var_nodes, list):
                  var_nodes = [var_nodes]
              op_desc.set_output(
                  output_name, [var_node.name() for var_node in var_nodes]
              )
          return IrOpNode(self.graph.create_op_node(op_desc))

      def create_op_node_from_desc(self, op_desc):
          """
          Create an operator node by using an existing OpDesc in the graph.

          Args:
              op_desc(core.OpDesc): the given operator description.

          Returns:
              IrOpNode: the created operator node.
          """
          return IrOpNode(self.graph.create_op_node(op_desc))

      def update_input_link(self, old_input_node, new_input_node, op_node):
          """
          Update the input's link of an operator node: unlink `old_input_node`,
          link `new_input_node`, and rename the input inside the op desc.

          Args:
              old_input_node(IrNode): the old input node of the given op_node.
              new_input_node(IrNode): the new input node of the given op_node.
              op_node(IrOpNode): the operator node whose input link is updated.
          """
          assert (
              old_input_node.node in self.graph.nodes()
              and new_input_node.node in self.graph.nodes()
              and op_node.node in self.graph.nodes()
          ), "The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes."
          old_input_node.remove_output(op_node)
          op_node.remove_input(old_input_node)
          new_input_node.append_output(op_node)
          op_node.append_input(new_input_node)
          op_node.rename_input(old_input_node.name(), new_input_node.name())

      def update_output_link(self, old_output_node, new_output_node, op_node):
          """
          Update the output's link of an operator node: unlink `old_output_node`,
          link `new_output_node`, and rename the output inside the op desc.

          Args:
              old_output_node(IrNode): the old output node of the given op_node.
              new_output_node(IrNode): the new output node of the given op_node.
              op_node(IrOpNode): the operator node whose output link is updated.
          """
          assert (
              old_output_node.node in self.graph.nodes()
              and new_output_node.node in self.graph.nodes()
              and op_node.node in self.graph.nodes()
          ), "The three arguments(old_output_node &new_output_node &op_node) must be in the graph nodes."
          old_output_node.remove_input(op_node)
          op_node.remove_output(old_output_node)
          new_output_node.append_input(op_node)
          op_node.append_output(new_output_node)
          op_node.rename_output(old_output_node.name(), new_output_node.name())

      def link_to(self, node_in, node_out):
          """
          Connect two nodes.

          Args:
              node_in(IrNode): the input node.
              node_out(IrNode): the output node.
          """
          assert node_in.node in self.graph.nodes(), (
              "node_in(%s) must be in the graph nodes." % node_in.node.name()
          )
          assert node_out.node in self.graph.nodes(), (
              "node_out(%s) must be in the graph nodes." % node_out.node.name()
          )
          node_in.append_output(node_out)
          node_out.append_input(node_in)

      def safe_remove_nodes(self, remove_nodes):
          """
          Remove nodes safely since links connected to these removed nodes are
          also removed.

          Args:
              remove_nodes(set): the nodes prepared to be removed. Any other
                  iterable is converted to a set, and a single non-iterable
                  node is wrapped into a one-element set.
          """
          if not isinstance(remove_nodes, set):
              if isinstance(remove_nodes, Iterable):
                  remove_nodes = set(remove_nodes)
              else:
                  remove_nodes = {remove_nodes}
          original_nodes = {n.node for n in remove_nodes}
          core.graph_safe_remove_nodes(self.graph, original_nodes)

      def resolve_hazard(self):
          """
          Group variable nodes by variable name, walking the op nodes in
          topological order, and pass the mapping to the underlying graph's
          `resolve_hazard`.

          For every variable name the list starts with the first node seen
          for that name (as an op input or output); each later op that
          outputs the same name appends its output node.
          """
          ordered_nodes = core.topology_sort(self.graph)
          var_nodes = {}
          for node in ordered_nodes:
              if node.is_op() and node.op() is not None:
                  for each_var_name in node.op().input_arg_names():
                      if each_var_name not in var_nodes:
                          var_nodes[each_var_name] = [
                              self._find_node_by_name(node.inputs, each_var_name)
                          ]
                  for each_var_name in node.op().output_arg_names():
                      if each_var_name not in var_nodes:
                          var_nodes[each_var_name] = [
                              self._find_node_by_name(node.outputs, each_var_name)
                          ]
                      else:
                          var_nodes[each_var_name].append(
                              self._find_node_by_name(node.outputs, each_var_name)
                          )
          self.graph.resolve_hazard(var_nodes)

      def has_circle(self):
          """
          Check if the graph has a circle.

          Returns:
              bool: True if the graph has a circle else False.
          """
          return core.has_circle(self.graph)

      def graph_num(self):
          """
          Count the number of unconnected graphs in this graph.

          Returns:
              int: the number of unconnected graphs.
          """
          return core.graph_num(self.graph)

      def topology_sort(self):
          """
          Perform the topology sort operation on the graph.

          Notes: the `graph` can not contain a circle.

          Returns:
              list(IrNode): nodes in topology order.
          """
          ordered_nodes = core.topology_sort(self.graph)
          return [IrNode(n) for n in ordered_nodes]

      def build_adjacency_list(self):
          """
          Build an adjacency list of operations for the `graph`.

          Returns:
              dict{IrNode: set(IrNode)}: the adjacency list.
          """
          adj_list = core.build_adjacency_list(self.graph)
          return {
              IrNode(k): {IrNode(n) for n in v} for k, v in adj_list.items()
          }

      @staticmethod
      def _convert_to_pdf(dot_file_path):
          """
          Try to render a dot file as a pdf next to it by running `dot`.

          If `dot` cannot be started (for example it is not installed) or
          exits with a non-zero code, a notice is printed and the dot file
          is left as the only output.
          """
          pdf_save_path = os.path.splitext(dot_file_path)[0] + ".pdf"
          try:
              exit_code = subprocess.call(
                  ["dot", "-Tpdf", dot_file_path, "-o", pdf_save_path]
              )
          except OSError:
              # A missing or non-executable `dot` raises instead of returning
              # an exit code; treat it like a failed conversion.
              exit_code = -1
          if exit_code != 0:
              print("The dot command is needed for creating pdf files.")
              print(f"The {dot_file_path} is saved as the dot filetype.")

      def draw(self, save_path, name, marked_nodes=None, remove_ctr_var=True):
          """
          Draw the graph. If `dot` command is installed, the drawn graph
          will be saved as pdf file type, otherwise dot file type is used.

          Args:
              save_path(str): the save path of drawn graph. The directory is
                  created when it does not exist.
              name(str): the name of drawn graph.
              marked_nodes(set(IrNode)): nodes that are needed to be marked.
                  Default value is None.
              remove_ctr_var(bool): If it is set True, all control variable nodes
                  in the graph will be removed. Default value is True.
          """
          remove_ctr_vars = set()
          if remove_ctr_var:
              for node in self.all_var_nodes():
                  if node.is_ctrl_var():
                      remove_ctr_vars.add(node)
              self.safe_remove_nodes(remove_ctr_vars)
          print(f"Total ops num = {len(self.all_op_nodes())}.")

          if marked_nodes is not None:
              if not isinstance(marked_nodes, set):
                  if isinstance(marked_nodes, Iterable):
                      marked_nodes = set(marked_nodes)
                  else:
                      marked_nodes = {marked_nodes}
              # Compare on the underlying nodes so that removed control
              # variables are not marked.
              marked_nodes = {n.node for n in marked_nodes}
              remove_ctr_vars = {n.node for n in remove_ctr_vars}
              marked_nodes = marked_nodes - remove_ctr_vars
              if self.graph.has("__graphviz__marked_node__"):
                  self.graph.erase("__graphviz__marked_node__")
              self.graph.set("__graphviz__marked_node__", marked_nodes)

          # exist_ok avoids the check-then-create race on the directory.
          os.makedirs(save_path, exist_ok=True)
          viz_dot_path = os.path.join(save_path, name) + ".dot"
          viz_pass = core.get_pass("graph_viz_pass")
          viz_pass.set("graph_viz_path", viz_dot_path)
          viz_pass.apply(self.graph)
          self._convert_to_pdf(viz_dot_path)

      def to_program(self):
          """
          Convert the graph into a Program.

          WARN: When the graph includes backward operator nodes, the
          conversion process may be failed. Usually, this function is
          only used to convert a test graph.

          Returns:
              Program: a program converted from the graph.
          """
          convert_pass = core.get_pass("graph_to_program_pass")
          desc = core.ProgramDesc()
          convert_pass.set_not_owned("program", desc)
          convert_pass.apply(self.graph)
          program = Program._construct_from_desc(desc)
          return program

      def _find_node_by_name(self, nodes, node_name):
          """
          Find a node in the given nodes by its name.

          When several nodes share the name, the last one in iteration order
          is returned. An AssertionError is raised when no node matches.
          """
          target_node = None
          for n in nodes:
              if n.name() == node_name:
                  target_node = n
          assert target_node is not None, (
              "Cannot find the target node (%s)in the giving set." % node_name
          )
          return target_node

      def _update_desc_attr(self, desc, name, val):
          """
          Update the value of desc's attribute by attribute's name, choosing
          the desc setter that matches the kind of `val`.
          """
          if isinstance(val, Variable):
              desc.set_var_attr(name, val.desc)
          elif isinstance(val, list) and _all_is_type(val, Variable):
              desc.set_vars_attr(name, [v.desc for v in val])
          elif isinstance(val, Block):
              desc.set_block_attr(name, val.desc)
          elif isinstance(val, list) and val and _all_is_type(val, Block):
              desc.set_blocks_attr(name, [v.desc for v in val])
          elif isinstance(val, (core.BlockDesc, core.ProgramDesc)):
              desc.set_serialized_attr(name, val.serialize_to_string())
          else:
              desc._set_attr(name, val)
  4799. class Program:
  4800. """
  4801. Create Python Program. It has at least one :ref:`api_guide_Block_en`, when the
  4802. control flow op like conditional_block, while :ref:`api_paddle_base_layers_While` is included,
  4803. it will contain nested block.
  4804. Please reference the
  4805. `framework.proto <https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/base/framework/framework.proto>`_
  4806. for details.
  4807. A set of Program usually contains startup program and main program.
  4808. A startup program is set to contain some initial work, eg. initialize the ``Parameter``, and the main
  4809. program will contain the network structure and vars for train.
  4810. A set of Program can be used for test or train, in train program ,
  4811. Paddle will contain all content to build a train network, in test
  4812. program Paddle will prune some content which is irrelevant to test, eg.
  4813. backward ops and vars.
  4814. **Notes**:
  4815. **we have** :ref:`api_paddle_base_framework_default_startup_program` **and** :ref:`api_paddle_base_framework_default_main_program`
  4816. **by default, a pair of them will shared the parameters. The** :ref:`api_paddle_base_framework_default_startup_program` **only run once to initialize parameters,**
  4817. :ref:`api_paddle_base_framework_default_main_program` **run in every mini batch and adjust the weights.**
  4818. Returns:
  4819. Program: An empty Program.
  4820. Examples:
  4821. .. code-block:: python
  4822. >>> import paddle
  4823. >>> import paddle.static as static
  4824. >>> paddle.enable_static()
  4825. >>> main_program = static.Program()
  4826. >>> startup_program = static.Program()
  4827. >>> with static.program_guard(main_program=main_program, startup_program=startup_program):
  4828. ... x = static.data(name="x", shape=[-1, 784], dtype='float32')
  4829. ... y = static.data(name="y", shape=[-1, 1], dtype='int32')
  4830. ... z = static.nn.fc(name="fc", x=x, size=10, activation="relu")
  4831. >>> print("main program is: {}".format(main_program))
  4832. >>> print("start up program is: {}".format(startup_program))
  4833. """
  def __init__(self):
      """
      Build an empty Program: a fresh ProgramDesc, a single root block,
      and default values for all bookkeeping fields.
      """
      self.desc = core.ProgramDesc()
      self.blocks = [Block(self, 0)]
      self.current_block_idx = 0
      # The module-level seed is only read here, so no `global` statement is needed.
      self._seed = global_prog_seed
      self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
      self.__op_role_var = []

      # ---- distributed training state ----
      # True when running under distributed training.
      self._is_distributed = False
      # True if the trainer is the first one, usually No.0.
      self._is_chief = False
      # Records all the parameters distributed on parameter servers.
      self._parameters_on_pservers = None
      # Parameter server endpoints, such as ["ip:port","ip:port"].
      self._endpoints = []
      # When the current role is parameter server, its own "ip:port".
      self._ps_endpoint = None
      # Trainer endpoints, used for distribution.
      self._trainers_endpoints = []
      # The distributed lookup table names.
      self._distributed_lookup_table = None

      # Whether Deep gradient compression is used.
      self._enable_dgc = False
      self._use_lamb = False

      self._nccl_comm_num = 1
      self._use_hierarchical_allreduce = False
      self._hierarchical_allreduce_inter_nranks = 0

      # Given a value once this program has been optimized by a
      # distributed optimizer.
      self._fleet_opt = None
      self._program_config = None

      self._pass_applied = None

      # Assigned once this program has been parsed by a pipeline optimizer.
      self._pipeline_opt = None
      self._pass_opt = None

      # Assigned once this program has been parsed by a heter pipeline
      # parameter server optimizer.
      self._heter_pipeline_opt = None

      # Number of times gradients have been appended.
      self._appending_grad_times = 0

      # Identifier for auto checkpoint.
      name_generator = unique_name.UniqueNameGenerator()
      self._name_generator = name_generator
      self._auto_checkpoint_name = name_generator(
          "__auto_checkpoint_program__"
      )

      # Compiled program, i.e. Graph.
      self._graph = None
      # Tags whether this is the startup_program.
      self._is_start_up_program_ = False

      # Distributed training combined with the prim mechanism (prim runs
      # after distributed): after distributed partition, for a subprogram or
      # subgraph on a single card, PHI grad ops are decomposed into primitive
      # ops. `_need_decomp` tags whether this program needs to be decomposed.
      self._need_decomp = False

      # A dict recording the mapping of backward grad variable to forward
      # variable.
      self._grad_var_to_var = None
  def _find_var_class_kwargs(self, new_desc):
      """
      For every variable in `new_desc`, work out which class (Parameter or
      Variable) and which constructor kwargs would rebuild it.

      Blocks missing from this program are created on the fly. Attributes
      that only exist on the Python side (error_clip, stop_gradient, is_data,
      belong_to_optimizer, and the Parameter-specific ones) are carried over
      from the variable of the same name in the matching block, when there
      is one.

      Returns:
          list(list(dict)): one list per block of `new_desc`; every entry is
          {"class": Parameter or Variable, "kwargs": dict}.
      """
      var_types = core.VarDesc.VarType
      # NOTE: not all variables support shape/dtype/lod_level methods.
      # For example: RAW, STEP_SCOPES, etc.
      shape_dtype_types = [
          var_types.LOD_TENSOR,
          var_types.SELECTED_ROWS,
          var_types.LOD_TENSOR_ARRAY,
      ]
      lod_level_types = [
          var_types.LOD_TENSOR,
          var_types.LOD_TENSOR_ARRAY,
      ]

      def read_if_supported(var_desc, getter_name, supported_types):
          # Call the getter only for variable types that support it.
          if var_desc.type() not in supported_types:
              return None
          return getattr(var_desc, getter_name)()

      vars_per_block = []
      for idx in range(new_desc.num_blocks()):
          if idx >= len(self.blocks):
              self._create_block()

          block_new_vars = []
          vars_per_block.append(block_new_vars)
          for new_var_desc in new_desc.block(idx).all_vars():
              var_name = new_var_desc.name()
              old_var = None
              if self.blocks[idx].has_var(var_name):
                  old_var = self.blocks[idx].var(var_name)

              # Python-side attributes come from the old variable, if any.
              if old_var is None:
                  error_clip = None
                  stop_gradient = False
                  is_data = False
                  belong_to_optimizer = False
              else:
                  error_clip = old_var.error_clip
                  stop_gradient = old_var.stop_gradient
                  is_data = old_var.is_data
                  belong_to_optimizer = old_var.belong_to_optimizer

              kwargs = {
                  "type": new_var_desc.type(),
                  "name": new_var_desc.name(),
                  "shape": read_if_supported(
                      new_var_desc, "shape", shape_dtype_types
                  ),
                  "dtype": read_if_supported(
                      new_var_desc, "dtype", shape_dtype_types
                  ),
                  "lod_level": read_if_supported(
                      new_var_desc, "lod_level", lod_level_types
                  ),
                  "error_clip": error_clip,
                  "stop_gradient": stop_gradient,
                  "is_data": is_data,
                  "need_check_feed": new_var_desc.need_check_feed(),
                  "belong_to_optimizer": belong_to_optimizer,
              }

              if isinstance(old_var, Parameter):
                  var_class = Parameter
                  kwargs["trainable"] = old_var.trainable
                  kwargs["optimize_attr"] = old_var.optimize_attr
                  kwargs["regularizer"] = old_var.regularizer
                  kwargs["do_model_average"] = old_var.do_model_average
                  kwargs["need_clip"] = old_var.need_clip
                  kwargs["is_distributed"] = old_var.is_distributed
                  kwargs["is_parameter"] = old_var.is_parameter
              else:
                  var_class = Variable
                  kwargs["persistable"] = new_var_desc.persistable()

              block_new_vars.append(
                  {"class": var_class, "kwargs": copy.deepcopy(kwargs)}
              )

      return vars_per_block
  def _rebuild_from_desc(self, desc):
      """
      Replace the content of this program with the content of `desc`.

      The steps run in a fixed order: collect the class/kwargs for every new
      variable, empty the Python-side vars and ops of each block, move each
      block desc over from `desc`, recreate the variables, and finally wrap
      every op desc in an Operator appended to its block.
      """
      all_new_vars = self._find_var_class_kwargs(desc)
      block_num = desc.num_blocks()
      assert block_num == len(all_new_vars)
      assert block_num == self.desc.num_blocks()

      # Clear the old Python-side content of every block first.
      for idx in range(block_num):
          stale_block = self.blocks[idx]
          stale_block.vars.clear()
          stale_block.ops.clear()

      # Then take over the block descs of the incoming program desc.
      for idx in range(block_num):
          self.blocks[idx].desc._move_from(desc.block(idx))

      del desc

      # Variables are added before any op is appended.
      for block, var_specs in zip(self.blocks, all_new_vars):
          for spec in var_specs:
              ctor_kwargs = spec["kwargs"]
              ctor_kwargs["block"] = block
              spec["class"](**ctor_kwargs)

      # Finally rebuild the Python Operator wrappers.
      for idx in range(block_num):
          block = self.blocks[idx]
          block_desc = self.desc.block(idx)
          for op_idx in range(block_desc.op_size()):
              rebuilt_op = Operator(block=block, desc=block_desc.op(op_idx))
              block.ops.append(rebuilt_op)
  5013. def global_seed(self, seed=0):
  5014. """
  5015. Set global seed for Program
  5016. Returns:
  5017. None.
  5018. Examples:
  5019. .. code-block:: python
  5020. >>> import paddle
  5021. >>> import paddle.static as static
  5022. >>> paddle.enable_static()
  5023. >>> prog = static.default_main_program()
  5024. >>> print(prog.random_seed)
  5025. 0
  5026. >>> ## the default random seed is 0
  5027. >>> prog.global_seed(102)
  5028. >>> prog1 = static.default_main_program()
  5029. >>> print(prog1.random_seed)
  5030. 102
  5031. >>> ## the random seed is 102
  5032. """
  5033. global global_prog_seed
  5034. global_prog_seed = seed
  5035. self._seed = global_prog_seed
  5036. @property
  5037. def _op_role(self):
  5038. """
  5039. The operator role. In a enum {Forward, Backward, Optimize}.
  5040. Notes: this is a low level API. It is used only for ParallelExecutor to
  5041. duplicate or schedule operator to devices.
  5042. For example, the forward operator should be executed on every device.
  5043. The backward operator should be executed on every device and the
  5044. parameter gradient of backward (use :code:`_op_role_var` to get this
  5045. variable) operator should be merged to one device. The optimization
  5046. operators should be executed on only one device and broadcast the
  5047. optimization result, i.e., the new parameter, to every other device.
  5048. """
  5049. return self._current_role
  5050. @_op_role.setter
  5051. def _op_role(self, role):
  5052. self._current_role = role
  5053. @property
  5054. def _op_role_var(self):
  5055. """
  5056. The auxiliary variables for :code:`_op_role` property.
  5057. See Also: :code:`Program._op_role`'s documentation for details.
  5058. Notes: This is a very low-level API. Users should not use it directly.
  5059. """
  5060. return self.__op_role_var
  5061. @signature_safe_contextmanager
  5062. def _backward_role_guard(self):
  5063. tmp_role = self._current_role
  5064. OpRole = core.op_proto_and_checker_maker.OpRole
  5065. self._current_role = OpRole.Backward
  5066. try:
  5067. yield
  5068. finally:
  5069. self._current_role = tmp_role
  5070. @signature_safe_contextmanager
  5071. def _optimized_guard(self, param_and_grads):
  5072. """
  5073. A with guard to set :code:`Optimization` :code:`OpRole` and
  5074. :code:`OpRoleVar` automatically.
  5075. Notes: This is a very low level API. Users should not use it directly.
  5076. Args:
  5077. param_and_grads(list): The variables (names) to be optimized.
  5078. Examples:
  5079. >>> import paddle.base as base
  5080. >>> p, g = backward(...)
  5081. >>> with program._optimized_guard([p,g]):
  5082. >>> p = p - 0.001 * g
  5083. """
  5084. tmp_role = self._current_role
  5085. tmp_var = self.__op_role_var
  5086. OpRole = core.op_proto_and_checker_maker.OpRole
  5087. self._current_role = OpRole.Optimize
  5088. self.__op_role_var = [
  5089. var.name if isinstance(var, Variable) else var
  5090. for var in param_and_grads
  5091. ]
  5092. try:
  5093. yield
  5094. finally:
  5095. self.__op_role_var = tmp_var
  5096. self._current_role = tmp_role
  5097. @signature_safe_contextmanager
  5098. def _lr_schedule_guard(self, is_with_opt=False):
  5099. """
  5100. A with guard to set :code:`LRSched` :code:`OpRole` and
  5101. :code:`OpRoleVar` automatically. The :code:`OpRoleVar` is
  5102. set to the target learning rate.
  5103. Notes: This is a very low level API. Users should not use it directly.
  5104. Args:
  5105. is_with_opt: Only set to true if these ops a in the middle
  5106. of a bunch of optimize ops so that it can be treated
  5107. correctly. For example, sgd->lr_op->sgd->lr_op->sgd.
  5108. Examples:
  5109. >>> import paddle.base as base
  5110. >>> p, g = backward(...)
  5111. >>> with program.lr_schedule_guard():
  5112. >>> lr = lr * decay
  5113. """
  5114. tmp_role = self._current_role
  5115. tmp_var = self.__op_role_var
  5116. OpRole = core.op_proto_and_checker_maker.OpRole
  5117. self._current_role = OpRole.LRSched
  5118. if is_with_opt:
  5119. self._current_role = int(OpRole.LRSched) | int(OpRole.Optimize)
  5120. # TODO(typhoonzero): how to set target learning rate var
  5121. self.__op_role_var = []
  5122. try:
  5123. yield
  5124. finally:
  5125. self.__op_role_var = tmp_var
  5126. self._current_role = tmp_role
  5127. def __str__(self):
  5128. """
  5129. Get the protobuf debug string of this Program.
  5130. Returns:
  5131. (str): The protobuf debug string.
  5132. Raises:
  5133. ValueError: If any of required fields is not set.
  5134. """
  5135. return self._to_readable_code()
  5136. def _to_readable_code(self, skip_op_callstack=True):
  5137. """
  5138. Get readable debug string of Program.
  5139. .. note::
  5140. If you want to get the debug string in protobuf format,
  5141. please use :code:`to_string` method.
  5142. Args:
  5143. skip_op_callstack(bool): whether to skip parsing Operator's attribute
  5144. op_callstack, default value is True
  5145. Returns:
  5146. string: The formatted Program string.
  5147. Examples:
  5148. .. code-block:: python
  5149. >>> import paddle
  5150. >>> import paddle.static as static
  5151. >>> paddle.enable_static()
  5152. >>> cur_program = static.Program()
  5153. >>> cur_block = cur_program.current_block()
  5154. >>> new_var = cur_block.create_var(name="X",
  5155. ... shape=[-1, 23, 48],
  5156. ... dtype='float32')
  5157. >>> new_op = cur_block.append_op(type="abs",
  5158. ... inputs={"X": [new_var]},
  5159. ... outputs={"Out": [new_var]})
  5160. >>> print(cur_program._to_readable_code())
  5161. """
  5162. assert isinstance(
  5163. skip_op_callstack, bool
  5164. ), f"skip_op_callstack parameter's type is error, expect bool, received {type(skip_op_callstack)}"
  5165. program_str = ""
  5166. for block in self.blocks:
  5167. program_str += block._to_readable_code(skip_op_callstack)
  5168. program_str += "\n"
  5169. return program_str
  5170. def to_string(self, throw_on_error, with_details=False):
  5171. """
  5172. To debug string.
  5173. Args:
  5174. throw_on_error (bool): raise Value error when any of required fields is not set.
  5175. with_details (bool): True if more details about variables and parameters, e.g., :code:`trainable`, :code:`optimize_attr`, need to print.
  5176. Returns:
  5177. str: The debug string describe current Program.
  5178. Raises:
  5179. ValueError: If any of required fields is not set and throw_on_error is True.
  5180. Examples:
  5181. .. code-block:: python
  5182. >>> import paddle
  5183. >>> import paddle.static as static
  5184. >>> paddle.enable_static()
  5185. >>> prog = static.default_main_program()
  5186. >>> x = static.data(name="X", shape=[2,3], dtype="float32")
  5187. >>> pred = static.nn.fc(x, size=3)
  5188. >>> prog_string = prog.to_string(throw_on_error=True, with_details=False)
  5189. >>> prog_string_with_details = prog.to_string(throw_on_error=False, with_details=True)
  5190. >>> print("program string without detail: {}".format(prog_string))
  5191. >>> print("program string with detail: {}".format(prog_string_with_details))
  5192. """
  5193. assert isinstance(
  5194. throw_on_error, bool
  5195. ), f"The type of throw_on_error parameter is wrong, expected bool, but received {type(throw_on_error)}."
  5196. assert isinstance(
  5197. with_details, bool
  5198. ), f"The type of with_details parameter is wrong, expected bool, but received {type(with_details)}."
  5199. if with_details:
  5200. res_str = ""
  5201. for block in self.blocks:
  5202. res_str += block.to_string(throw_on_error, with_details)
  5203. protostr = self.desc.serialize_to_string()
  5204. proto = framework_pb2.ProgramDesc.FromString(bytes(protostr))
  5205. res_str += (
  5206. "version {\n "
  5207. + textwrap.indent(
  5208. _debug_string_(proto.version, throw_on_error), " "
  5209. )
  5210. + "}\n"
  5211. )
  5212. res_str += (
  5213. "op_version_map {\n "
  5214. + textwrap.indent(
  5215. _debug_string_(proto.op_version_map, throw_on_error), " "
  5216. )
  5217. + "}\n"
  5218. )
  5219. else:
  5220. protostr = self.desc.serialize_to_string()
  5221. proto = framework_pb2.ProgramDesc.FromString(bytes(protostr))
  5222. res_str = _debug_string_(proto, throw_on_error)
  5223. return res_str
  5224. def _get_desc(self):
  5225. """
  5226. Get the C++ side of `ProgramDesc` object pointer. The C++ object is
  5227. exposed by :code:`pybind`.
  5228. Notes: This is a very low level API. Users should not use this API
  5229. directly.
  5230. """
  5231. return self.desc
  5232. def _version(self):
  5233. return self.desc._version()
  5234. def clone(self, for_test=False):
  5235. """
  5236. .. note:::
  5237. 1. :code:`Program.clone()` method DOES NOT clone :ref:`api_paddle_io_DataLoader` .
  5238. 2. Recommend you to use :code:`clone` before using :code:`Optimizer.minimize` .
  5239. 3. This API has no effect in Dygraph Mode.
  5240. Create a new Program with forward content of original one when ``for_test=True``.
  5241. Create a new Program as same as the original one when ``for_test=False``.
  5242. Some operators, e.g., :ref:`api_paddle_base_layers_batch_norm` , behave differently between
  5243. training and testing. They have an attribute, :code:`is_test`, to
  5244. control this behaviour. This method will change the :code:`is_test`
  5245. attribute of them to :code:`True` when :code:`for_test=True`.
  5246. * Set for_test to False when you want to clone the program for training.
  5247. * Set for_test to True when you want to clone the program for testing.
  5248. We will prune the backward and optimize part of the program when you
  5249. use :code:`clone` after :code:`Optimizer.minimize`, but we still
  5250. recommend you to use :code:`clone` before using :code:`Optimizer.minimize`.
  5251. Examples:
  5252. .. code-block:: python
  5253. :name: code-example-1
  5254. >>> import paddle
  5255. >>> import paddle.static as static
  5256. >>> paddle.enable_static()
  5257. >>> img = static.data(name='image', shape=[None, 784])
  5258. >>> pred = static.nn.fc(x=img, size=10, activation='relu')
  5259. >>> loss = paddle.mean(pred)
  5260. >>> # Here we use clone before Momentum
  5261. >>> test_program = static.default_main_program().clone(for_test=True)
  5262. >>> optimizer = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
  5263. >>> optimizer.minimize(loss)
  5264. Args:
  5265. for_test (bool): True if change the :code:`is_test` attribute of operators to :code:`True`
  5266. and prune the backward and optimize part of the program. The default value is :code:`False` .
  5267. Returns:
  5268. Program: A new Program with forward content of original one when ``for_test=True``. A new Program as same as the original one when ``for_test=False``
  5269. Examples:
  5270. .. note::
  5271. The Program's order maybe different after :code:`clone` and
  5272. this will not affect your training or testing progress. In the following
  5273. example we give you an simple method :code:`print_prog(program)` to
  5274. print Program Descs inorder to make sure you have same print result
  5275. after :code:`clone`:
  5276. .. code-block:: python
  5277. :name: code-example-2
  5278. >>> import paddle
  5279. >>> def print_prog(prog):
  5280. ... for name, value in sorted(prog.block(0).vars.items()):
  5281. ... print(value)
  5282. ... for op in prog.block(0).ops:
  5283. ... print("op type is {}".format(op.type))
  5284. ... print("op inputs are {}".format(op.input_arg_names))
  5285. ... print("op outputs are {}".format(op.output_arg_names))
  5286. ... for key, value in sorted(op.all_attrs().items()):
  5287. ... if key not in ['op_callstack', 'op_role_var']:
  5288. ... print(" [ attrs: {}: {} ]".format(key, value))
  5289. 1. To clone a test program, the sample code is:
  5290. .. code-block:: python
  5291. :name: code-example-3
  5292. >>> import paddle
  5293. >>> import paddle.static as static
  5294. >>> import paddle.utils as utils
  5295. >>> import paddle.nn.functional as F
  5296. >>> paddle.enable_static()
  5297. >>> def print_prog(prog):
  5298. ... for name, value in sorted(prog.block(0).vars.items()):
  5299. ... print(value)
  5300. ... for op in prog.block(0).ops:
  5301. ... print("op type is {}".format(op.type))
  5302. ... print("op inputs are {}".format(op.input_arg_names))
  5303. ... print("op outputs are {}".format(op.output_arg_names))
  5304. ... for key, value in sorted(op.all_attrs().items()):
  5305. ... if key not in ['op_callstack', 'op_role_var']:
  5306. ... print(" [ attrs: {}: {} ]".format(key, value))
  5307. >>> train_program = static.Program()
  5308. >>> startup_program = static.Program()
  5309. >>> # startup_program is used to do some parameter init work,
  5310. >>> # and main program is used to hold the network
  5311. >>> with static.program_guard(train_program, startup_program):
  5312. ... with utils.unique_name.guard():
  5313. ... img = static.data(name='image', shape=[None, 784])
  5314. ... hidden = static.nn.fc(x=img, size=200, activation='relu')
  5315. ... hidden = F.dropout(hidden, p=0.5)
  5316. ... loss = F.cross_entropy(
  5317. ... input=static.nn.fc(x=hidden, size=10, activation='softmax'),
  5318. ... label=static.data(name='label', shape=[1], dtype='int64'))
  5319. ... avg_loss = paddle.mean(loss)
  5320. ... test_program = train_program.clone(for_test=True)
  5321. >>> print_prog(test_program)
  5322. >>> # Due to parameter sharing usage for train and test, so we need to use startup program of train
  5323. >>> # instead of using test startup program, while nothing is in test's startup program
  5324. >>> # In Paddle we will share weights by using the same Tensor name. In train and test program
  5325. >>> # all parameters will have the same name and this can make train and test program sharing parameters,
  5326. >>> # that's why we need to use startup program of train. And for startup program of test, it has nothing,
  5327. >>> # since it is a new program.
  5328. >>> with static.program_guard(train_program, startup_program):
  5329. ... with utils.unique_name.guard():
  5330. ... sgd = paddle.optimizer.SGD(learning_rate=1e-3)
  5331. ... sgd.minimize(avg_loss)
  5332. 2. The clone method can be avoid if you create program for training and program for testing individually.
  5333. .. code-block:: python
  5334. :name: code-example-4
  5335. >>> import paddle
  5336. >>> import paddle.static as static
  5337. >>> import paddle.utils as utils
  5338. >>> import paddle.nn.functional as F
  5339. >>> paddle.enable_static()
  5340. >>> def print_prog(prog):
  5341. ... for name, value in sorted(prog.block(0).vars.items()):
  5342. ... print(value)
  5343. ... for op in prog.block(0).ops:
  5344. ... print("op type is {}".format(op.type))
  5345. ... print("op inputs are {}".format(op.input_arg_names))
  5346. ... print("op outputs are {}".format(op.output_arg_names))
  5347. ... for key, value in sorted(op.all_attrs().items()):
  5348. ... if key not in ['op_callstack', 'op_role_var']:
  5349. ... print(" [ attrs: {}: {} ]".format(key, value))
  5350. >>> def network():
  5351. ... img = static.data(name='image', shape=[None, 784])
  5352. ... hidden = static.nn.fc(x=img, size=200, activation='relu')
  5353. ... hidden = F.dropout(hidden, p=0.5)
  5354. ... loss = F.cross_entropy(
  5355. ... input=static.nn.fc(x=hidden, size=10, activation='softmax'),
  5356. ... label=static.data(name='label', shape=[1], dtype='int64'))
  5357. ... avg_loss = paddle.mean(loss)
  5358. ... return avg_loss
  5359. >>> train_program_2 = static.Program()
  5360. >>> startup_program_2 = static.Program()
  5361. >>> test_program_2 = static.Program()
  5362. >>> with static.program_guard(train_program_2, startup_program_2):
  5363. ... with utils.unique_name.guard():
  5364. ... avg_loss = network()
  5365. ... sgd = paddle.optimizer.SGD(learning_rate=1e-3)
  5366. ... sgd.minimize(avg_loss)
  5367. >>> # the test startup program is not used.
  5368. >>> with static.program_guard(test_program_2, startup_program_2):
  5369. ... with utils.unique_name.guard():
  5370. ... avg_loss = network()
  5371. >>> print_prog(test_program_2)
  5372. The two code snippets above will generate and print same programs.
  5373. """
  5374. # NOTE(zhiqiu): we sync the original program first, since its program may diff with
  5375. # its desc due to modifying desc in c++ space. E.g. save op will add kLookupTablePath in desc.
  5376. self._sync_with_cpp()
  5377. pruned_origin_block_id_map = None
  5378. if for_test:
  5379. forward_prog = Program()
  5380. forward_prog.desc, pruned_origin_block_id_map = core.prune_backward(
  5381. self.desc
  5382. )
  5383. forward_prog.blocks = [
  5384. Block(forward_prog, i)
  5385. for i in range(forward_prog.desc.num_blocks())
  5386. ]
  5387. forward_prog._sync_with_cpp()
  5388. p = forward_prog._inference_optimize(prune_read_op=False)
  5389. else:
  5390. p = Program()
  5391. p.current_block_idx = self.current_block_idx
  5392. p._seed = self._seed
  5393. p.desc = core.ProgramDesc(self.desc)
  5394. p.blocks = [Block(p, i) for i in range(self.desc.num_blocks())]
  5395. p._current_role = self._current_role
  5396. p.__op_role_var = self.__op_role_var
  5397. p._appending_grad_times = self._appending_grad_times
  5398. if hasattr(self, "lr_scheduler"):
  5399. p.lr_scheduler = self.lr_scheduler
  5400. if hasattr(self, "_pipeline_opt"):
  5401. p._pipeline_opt = self._pipeline_opt
  5402. if hasattr(self, "_pass_opt"):
  5403. p._pass_opt = self._pass_opt
  5404. if hasattr(self, "_need_decomp"):
  5405. p._need_decomp = self._need_decomp
  5406. if hasattr(self, "_grad_var_to_var"):
  5407. p._grad_var_to_var = self._grad_var_to_var
  5408. # NOTE(zhiqiu): we sync the cloned program, to update its program by
  5409. # its desc.
  5410. p._sync_with_cpp()
  5411. p._copy_param_info_from(self)
  5412. p._copy_data_info_from(self, pruned_origin_block_id_map)
  5413. p._copy_dist_param_info_from(self)
  5414. p._copy_operator_info_from(self)
  5415. p._name_generator = self._name_generator.clone()
  5416. return p
  5417. @signature_safe_contextmanager
  5418. def switch_name_generator_guard(self, new_generator):
  5419. if isinstance(new_generator, str):
  5420. new_generator = unique_name.UniqueNameGenerator(new_generator)
  5421. elif isinstance(new_generator, bytes):
  5422. new_generator = unique_name.UniqueNameGenerator(
  5423. new_generator.decode()
  5424. )
  5425. old_generator = self._name_generator
  5426. self._name_generator = new_generator
  5427. try:
  5428. yield
  5429. finally:
  5430. self._name_generator = old_generator
  5431. def _prune(self, targets):
  5432. """
  5433. Prune operators and variables which are not needed to generate
  5434. :code:`targets`.
  5435. Notes: This is a very low level API. Users should not use this API
  5436. directly. This API is in flux and not stable.
  5437. Args:
  5438. targets(list|Variable|Operator): A list of variables, operators, or variable names
  5439. need to be pruned
  5440. Returns:
  5441. Program: A new, pruned program.
  5442. """
  5443. return self._prune_with_input([], targets)
  5444. def _prune_with_input(self, feeded_var_names, targets):
  5445. """
  5446. Prune operators and variables which are not needed to generate
  5447. :code:`targets`. Prune operators and variables which are needed
  5448. to generate feeded_var
  5449. Notes: This is a very low level API. Users should not use this API
  5450. directly. This API is in flux and not stable.
  5451. Args:
  5452. feeded_var_names(list|str): A list of variable names from where
  5453. pruning start. If it is set as [], this API works just like _prune()
  5454. targets(list|Variable|Operator): A list of variables, operators, or variable names
  5455. need to be pruned
  5456. Returns:
  5457. Program: A new, pruned program.
  5458. """
  5459. # NOTE(zhiqiu): we sync the original program first, since its program may diff with
  5460. # its desc due to modifying desc in c++ space. E.g. save op will add kLookupTablePath in desc.
  5461. self._sync_with_cpp()
  5462. if not isinstance(feeded_var_names, list):
  5463. feeded_var_names = [feeded_var_names]
  5464. if not isinstance(targets, list):
  5465. targets = [targets]
  5466. for var in feeded_var_names:
  5467. if not isinstance(var, str):
  5468. raise ValueError(
  5469. "All feeded_var_names of Program._prune_with_input() can only be "
  5470. "str, but received %s." % type(var)
  5471. )
  5472. # find out all variables that can be generated or updated with given feed
  5473. generatable_vars = set()
  5474. for idx, op in enumerate(self.global_block().ops):
  5475. runnable_op = True
  5476. for name in op.input_arg_names:
  5477. if not self.global_block().has_var(name):
  5478. continue
  5479. if self.global_block().var(name).persistable:
  5480. continue
  5481. if name not in generatable_vars.union(feeded_var_names):
  5482. runnable_op = False
  5483. break
  5484. if runnable_op:
  5485. generatable_vars = generatable_vars.union(op.output_arg_names)
  5486. targets_idx = []
  5487. for t in targets:
  5488. if not isinstance(t, Operator):
  5489. if isinstance(t, Variable):
  5490. name = t.name
  5491. elif isinstance(t, str):
  5492. name = str(t)
  5493. else:
  5494. raise ValueError(
  5495. "All targets of Program._prune_with_input() can only be "
  5496. "Variable or Operator, but received %s." % type(t)
  5497. )
  5498. # NOTE(zhiqiu): For variable to be fed in fetch_list, there two cases:
  5499. # (1) the variable is leaf, it has no op that generates it;
  5500. # (2) the variable is not leaf, and we need to prune the op that generates it.
  5501. # In both cases, wo can just skip target_op of that it.
  5502. if name in feeded_var_names:
  5503. # however if the var is also updated by a runnable op, will shall keep it
  5504. if name not in generatable_vars:
  5505. continue
  5506. # After transpiler processing, the op that output this
  5507. # variable maybe has been changed, so t.op is not reliable
  5508. # and we need to find the current op that generate this
  5509. # variable here.
  5510. target_op = None
  5511. global_block = self.global_block()
  5512. for idx, op in enumerate(global_block.ops):
  5513. if name in op.output_arg_names:
  5514. # NOTE(zhiqiu): Find op that generate target name.
  5515. # Skip optimize op except for optimize op in targets,
  5516. # since optimize op generates parameters.
  5517. if op._is_optimize_op() and op not in targets:
  5518. continue
  5519. else:
  5520. target_op = op
  5521. if target_op is not None:
  5522. targets_idx.append([target_op.block.idx, target_op.idx])
  5523. else:
  5524. targets_idx.append([t.block.idx, t.idx])
  5525. res = Program()
  5526. res.desc, pruned_origin_block_id_map = core.prune(
  5527. self.desc, set(feeded_var_names), targets_idx
  5528. )
  5529. res.blocks = [Block(res, i) for i in range(res.desc.num_blocks())]
  5530. res._sync_with_cpp()
  5531. res._copy_param_info_from(self)
  5532. res._copy_data_info_from(self, pruned_origin_block_id_map)
  5533. res._copy_dist_param_info_from(self)
  5534. res._copy_operator_info_from(self)
  5535. return res
  5536. def _inference_optimize(self, prune_read_op=True):
  5537. """
  5538. This method will create a new program and do following adjustments on it:
  5539. 1. Remove all reader variables and their creator ops if exist.
  5540. 2. Remove the :code:`read_op` if exists.
  5541. 3. change the :code:`is_test`
  5542. attribute of operators to :code:`True`. All the :code:`Parameter`
  5543. information will be lost.
  5544. Args:
  5545. prune_read_op(bool): remove the read ops that are added by py_reader
  5546. for cpp inference library
  5547. Notes: This API is a very low level API. Use
  5548. :code:`Program.clone(for_test=True)` instead.
  5549. Returns:
  5550. Program: The new program.
  5551. """
  5552. res = Program()
  5553. res.desc = core.ProgramDesc(self.desc)
  5554. # remove all readers and the read_op if exist
  5555. read_op_idx = 0
  5556. root_block = res.desc.block(0)
  5557. if prune_read_op:
  5558. while True:
  5559. if (
  5560. read_op_idx >= root_block.op_size()
  5561. or root_block.op(read_op_idx).type() == "read"
  5562. ):
  5563. break
  5564. read_op_idx += 1
  5565. if read_op_idx < root_block.op_size():
  5566. root_block._remove_op(0, read_op_idx + 1)
  5567. for var in root_block.all_vars():
  5568. if var.type() == core.VarDesc.VarType.READER:
  5569. root_block._remove_var(var.name().encode())
  5570. # change all `is_test` attributes to True
  5571. for i in range(res.desc.num_blocks()):
  5572. block = res.desc.block(i)
  5573. for j in range(block.op_size()):
  5574. op = block.op(j)
  5575. if op.has_attr("is_test"):
  5576. op._set_bool_attr("is_test", True)
  5577. if op.type() == "batch_norm":
  5578. # Remove the output ReserveSpace of batch_norm if exists.
  5579. op.remove_output("ReserveSpace")
  5580. res.blocks = [Block(res, i) for i in range(res.desc.num_blocks())]
  5581. res._sync_with_cpp()
  5582. return res
    def _remove_training_info(self, clip_extra=True):
        """
        This method will create a new program and do following adjustments on it:
        1. Remove all variable's `is_parameter` attribute if exist.

        2. Remove all variable's `stop_gradient` attribute if exist.

        3. When ``clip_extra`` is true, additionally strip from every op that
        has a registered proto: the outputs that are marked ``extra`` in the
        proto or unknown to it, the attributes listed by
        :code:`core.get_op_extra_attrs`, and further attributes as described
        by the inline comments below.

        Notes: This API is a very low level API.

        Args:
            clip_extra(bool): Whether to also clip the extra outputs and
                attributes of operators. Default is True.

        Returns:
            Program: The new program.
        """
        res = Program()
        res.desc = core.ProgramDesc(self.desc)

        res.blocks = [Block(res, i) for i in range(res.desc.num_blocks())]
        res._sync_with_cpp()

        # Note: The op_role and op_role_var cann't be deleted currently,
        # and we will try to remove them in the future.
        common_clipped_attrs_list = ["op_callstack", "with_quant_attr"]

        for i in range(res.desc.num_blocks()):
            block = res.desc.block(i)
            for var in block.all_vars():
                var.clear_is_parameter()
                var.clear_stop_gradient()
            # Without clip_extra only the variable flags above are cleared.
            if not clip_extra:
                continue
            for op_idx in range(0, block.op_size()):
                op = block.op(op_idx)
                # Ops without a registered proto are left untouched.
                if op.type() not in OpProtoHolder.instance().op_proto_map:
                    continue

                extra_attrs_map = core.get_op_extra_attrs(op.type())

                proto = OpProtoHolder.instance().get_op_proto(op.type())
                # Collect inputs that are marked extra in the proto or are
                # missing from it. The list is only collected: the removal
                # itself is still commented out below.
                remove_input_list = []
                for name in op.input_names():
                    find = False
                    for input_proto in proto.inputs:
                        if input_proto.name != name:
                            continue
                        if input_proto.extra:
                            remove_input_list.append(name)
                        find = True
                        break
                    if not find:
                        remove_input_list.append(name)
                # The extra input of op will be removed in the future
                # for name in remove_input_list:
                #     op.remove_input(name)

                # Same classification for outputs; these ARE removed.
                remove_output_list = []
                for name in op.output_names():
                    find = False
                    for output_proto in proto.outputs:
                        if output_proto.name != name:
                            continue
                        if output_proto.extra:
                            remove_output_list.append(name)
                        find = True
                        break
                    if not find:
                        remove_output_list.append(name)
                # The extra output of op will be removed in the future
                for name in remove_output_list:
                    op.remove_output(name)

                op_quant_name = (
                    core.op_proto_and_checker_maker.kOpWithQuantAttrName()
                )
                # True only when the op carries the quant marker attribute
                # and its value is truthy.
                quant = (
                    bool(op.attr(op_quant_name))
                    if op_quant_name in op.attr_names()
                    else False
                )
                # Attributes that are kept on ops for which `quant` is true.
                quant_attrs = [
                    op_quant_name,
                    "quantization_type",
                    "skip_quant",
                    "activation_bits",
                    "bit_length",
                    "quantize_weight_bits",
                    "weight_quant_scale",
                ]
                for extra_attr_name in extra_attrs_map.keys():
                    op.remove_attr(extra_attr_name)
                remove_attr_list = []
                for name in op.attr_names():
                    if quant:
                        # Keep quantization attributes and "*_threshold".
                        if name in quant_attrs:
                            continue
                        if name.endswith("_threshold"):
                            continue
                    if len(extra_attrs_map) > 0:
                        # Ops with an extra-attr map: only the common clipped
                        # attributes are dropped here and the proto lookup
                        # below is skipped for every attribute.
                        if name in common_clipped_attrs_list:
                            op.remove_attr(name)
                        continue
                    # Ops without an extra-attr map: drop every attribute
                    # that the proto does not declare.
                    find = False
                    for attr_proto in proto.attrs:
                        if attr_proto.name != name:
                            continue
                        find = True
                        break
                    if not find:
                        remove_attr_list.append(name)
                for name in remove_attr_list:
                    op.remove_attr(name)
        return res
  5683. @staticmethod
  5684. def parse_from_string(binary_str):
  5685. """
  5686. .. note::
  5687. 1. All information about parameters will be lost after serialization;
  5688. 2. This API has no effect in Dygraph mode.
  5689. Deserialize a Program from `protobuf <https://en.wikipedia.org/wiki/Protocol_Buffers>`_ binary string.
  5690. This method always use to save and load model
  5691. Args:
  5692. binary_str_type (str): the binary protobuf string.
  5693. Returns:
  5694. Program: A deserialized Program.
  5695. Examples:
  5696. .. code-block:: python
  5697. >>> import paddle
  5698. >>> import paddle.static as static
  5699. >>> paddle.enable_static()
  5700. >>> startup_prog = static.Program()
  5701. >>> main_prog = static.Program()
  5702. >>> with static.program_guard(startup_prog, main_prog):
  5703. ... x = static.data(name='X', shape=[1000, 784], dtype='float32')
  5704. ... y = static.data(name='Y', shape=[784, 100], dtype='float32')
  5705. ... z = paddle.matmul(x=x, y=y)
  5706. ... binary_str = static.default_main_program().desc.serialize_to_string()
  5707. ... prog_restored = static.default_main_program().parse_from_string(binary_str)
  5708. ... print(static.default_main_program())
  5709. ... print(prog_restored)
  5710. """
  5711. p = Program()
  5712. p.desc = core.ProgramDesc(binary_str)
  5713. p.blocks = [Block(p, i) for i in range(p.desc.num_blocks())]
  5714. p._sync_with_cpp()
  5715. return p
  5716. @staticmethod
  5717. def _construct_from_desc(desc):
  5718. """
  5719. Construct a program from program desc.
  5720. Args:
  5721. desc(core.ProgramDesc): The program desc for constructing.
  5722. Returns:
  5723. Program: A program.
  5724. """
  5725. p = Program()
  5726. p.desc = desc
  5727. p.blocks = [Block(p, i) for i in range(p.desc.num_blocks())]
  5728. p._sync_with_cpp()
  5729. return p
  @property
  def random_seed(self):
      """
      Default seed used by the random operators of this Program. A value of
      ``0`` means the seed is taken from the random device.

      .. note::
          It must be set before the operators have been added.

      Returns:
          int64: Random seed in current Program

      Examples:
          .. code-block:: python

              >>> import paddle
              >>> import paddle.static as static
              >>> import paddle.nn.functional as F

              >>> paddle.enable_static()

              >>> prog = static.default_main_program()
              >>> random_seed = prog.random_seed
              >>> x_var = static.data(name="X", shape=[3,3], dtype="float32")
              >>> print(random_seed)
              0
              >>> ## the default random seed is 0

              >>> # Here we need to set random seed before we use paddle.nn.functional.dropout
              >>> prog.random_seed = 1
              >>> z_var = F.dropout(x_var, 0.7)

              >>> print(prog.random_seed)
              1
              >>> ## the random seed is changed to 1
      """
      current_seed = self._seed
      return current_seed
  @property
  def num_blocks(self):
      """
      How many :ref:`api_guide_Block_en` this Program contains, as reported
      by the underlying program desc.

      .. note::
          This API has no effect in Dygraph mode.

      Returns:
          int(Platform-dependent size): num of :ref:`api_guide_Block_en` in current Program

      Examples:
          .. code-block:: python

              >>> import paddle
              >>> import paddle.static as static

              >>> paddle.enable_static()

              >>> prog = static.default_main_program()
              >>> num_blocks = prog.num_blocks
              >>> print(num_blocks)
              1
      """
      program_desc = self.desc
      return program_desc.num_blocks()
  @random_seed.setter
  def random_seed(self, seed):
      # Accept integers only; anything else is reported to the caller.
      if isinstance(seed, int):
          self._seed = seed
          return
      raise ValueError(
          "Program.random_seed's input seed must be an integer, but received %s."
          % type(seed)
      )
  5786. def __repr__(self):
  5787. return self.__str__()
  def global_block(self):
      """
      .. note::
          This API has no effect in Dygraph mode.

      Return the :ref:`api_guide_Block_en` stored at index 0 of this Program.

      Returns:
          :ref:`api_guide_Block_en`: The first :ref:`api_guide_Block_en` of this Program.

      Examples:
          .. code-block:: python

              >>> import paddle
              >>> import paddle.static as static

              >>> paddle.enable_static()

              >>> prog = static.default_main_program()
              >>> gb_block = prog.global_block()
              >>> print(gb_block)
      """
      all_blocks = self.blocks
      return all_blocks[0]
  def block(self, index):
      """
      .. note::
          This API has no effect in Dygraph mode.

      Look up one :ref:`api_guide_Block_en` of this Program by its position.

      Args:
          index (int): The index of :ref:`api_guide_Block_en` to get

      Returns:
          :ref:`api_guide_Block_en`: The block stored at position :code:`index`

      Examples:
          .. code-block:: python

              >>> import paddle
              >>> import paddle.static as static

              >>> paddle.enable_static()

              >>> prog = static.default_main_program()
              >>> block_0 = prog.block(0)
              >>> print(block_0)
      """
      all_blocks = self.blocks
      return all_blocks[index]
  def current_block(self):
      """
      .. note::
          This API has no effect in Dygraph mode.

      Return the current :ref:`api_guide_Block_en`, i.e. the
      :ref:`api_guide_Block_en` that operators are appended to. It is the
      block stored at position ``current_block_idx``.

      Returns:
          :ref:`api_guide_Block_en`: The current :ref:`api_guide_Block_en`

      Examples:
          .. code-block:: python

              >>> import paddle
              >>> import paddle.static as static

              >>> paddle.enable_static()

              >>> prog = static.default_main_program()
              >>> current_blk = prog.current_block()
              >>> print(current_blk)
      """
      active_idx = self.current_block_idx
      return self.blocks[active_idx]
  def _create_block(self, parent_idx=None):
      """
      Append a new block to this program and make it the current block.

      Args:
          parent_idx(int): Index of the parent block. When it is None, the
              current block becomes the parent.

      Returns:
          Block: The new block.
      """
      created_idx = len(self.blocks)
      if parent_idx is None:
          parent_block = self.current_block()
      else:
          parent_block = self.block(parent_idx)

      # Register the block on the desc first, then mirror it in Python.
      self.desc.append_block(parent_block.desc)
      self.current_block_idx = created_idx
      self.blocks.append(Block(self, self.current_block_idx))
      return self.current_block()
  5861. def _roll_to_global_block(self):
  5862. self.current_block_idx = 0
  5863. def _rollback(self):
  5864. """
  5865. Exit a code block, i.e., roll back to the parent block.
  5866. Returns:
  5867. None
  5868. """
  5869. self.current_block_idx = self.current_block().parent_idx
  5870. def _sync_with_cpp(self):
  5871. """
  5872. Synchronize Python instance to its binding C++ object instance.
  5873. If the program is modified in C++ space, this method should be invoked.
  5874. Notes: This is a very low level API. Users should not invoke it
  5875. directly.
  5876. Returns:
  5877. None
  5878. """
  5879. for block_idx in range(len(self.blocks), self.desc.num_blocks()):
  5880. self.blocks.append(Block(self, block_idx))
  5881. for block in self.blocks:
  5882. block._sync_with_cpp()
  def _copy_param_info_from(self, other):
      """
      Copy parameter information from the global block of another program
      into the global block of this one.

      Notes: This is a very low level API. Users should not invoke it
      directly.

      Args:
          other(Program): Other program

      Returns:
          None

      Raises:
          TypeError: If ``other`` is not a Program.
      """
      if isinstance(other, Program):
          target_block = self.global_block()
          target_block._copy_param_info_from(other.global_block())
          return
      raise TypeError(
          "Function Program._copy_param_info_from() needs to pass in a source Program, but received %s"
          % type(other)
      )
  def _copy_dist_param_info_from(self, other):
      """
      Copy the distributed-training attributes from another program.

      The following attributes are copied by reference: ``_is_distributed``,
      ``_is_chief``, ``_parameters_on_pservers``, ``_endpoints``,
      ``_ps_endpoint`` and ``_distributed_lookup_table``.

      Args:
          other(Program): Other program

      Returns:
          None

      Raises:
          TypeError: If ``other`` is not a Program.
      """
      if not isinstance(other, Program):
          # The message names this function (it used to name
          # _copy_param_info_from, which sent users to the wrong call site).
          raise TypeError(
              "Function Program._copy_dist_param_info_from() needs to pass in a source Program, but received %s"
              % type(other)
          )
      self._is_distributed = other._is_distributed
      self._is_chief = other._is_chief
      self._parameters_on_pservers = other._parameters_on_pservers
      self._endpoints = other._endpoints
      self._ps_endpoint = other._ps_endpoint
      self._distributed_lookup_table = other._distributed_lookup_table
  def _copy_data_info_from(self, other, pruned_origin_block_id_map=None):
      """
      Copy the information of data variables from other program.

      For every variable of every block in this program, the variable with
      the same name is looked up in the mapped block of ``other`` and its
      ``is_data``, ``need_check_feed`` and ``stop_gradient`` flags are
      propagated when they are set on the source variable.

      Notes: This is a very low level API. Users should not invoke it
      directly.

      Args:
          other(Program): Other program
          pruned_origin_block_id_map(dict{int:int}): A dict which maps the block id in program
              self to the block id in program other. For example, {0:0, 1:1, 2:3} means block 0 in self is
              cloned from block 0 in other, etc. Default is None, which means default mapped,
              {0:0, 1:1,..., n:n}.

      Returns:
          None

      Raises:
          TypeError: If ``other`` is not a Program.
      """
      if not isinstance(other, Program):
          # The message names this function (it used to name
          # _copy_param_info_from, which sent users to the wrong call site).
          raise TypeError(
              "Function Program._copy_data_info_from() needs to pass in a source Program, but received %s"
              % type(other)
          )

      if not pruned_origin_block_id_map:
          pruned_origin_block_id_map = {
              i: i for i in range(self.desc.num_blocks())
          }

      # NOTE(zhiqiu): All vars in cloned program exist in original program.
      # The reverse is not true, due to backward pruning.
      for i, block in enumerate(self.blocks):
          other_block = other.blocks[pruned_origin_block_id_map[i]]
          # Iterate over a snapshot of the variables of this block.
          for var in list(block.vars.values()):
              other_var = other_block.var(var.name)
              if other_var.is_data:
                  var.is_data = True
              if other_var.desc.need_check_feed():
                  var.desc.set_need_check_feed(True)
              if other_var.stop_gradient:
                  var.stop_gradient = True
  def _copy_operator_info_from(self, other: Program):
      """
      Copy per-operator information (amp options and ``struct_name``) from
      another program. Blocks, and the operators inside them, are paired by
      position.

      Args:
          other(Program): Other program

      Returns:
          None

      Raises:
          TypeError: If ``other`` is not a Program.
      """
      if not isinstance(other, Program):
          raise TypeError(
              f"Function Program._copy_operator_info_from() needs to pass in a source Program, but received {type(other)}"
          )
      for target_block, source_block in zip(self.blocks, other.blocks):
          op_pairs = zip(target_block.ops, source_block.ops)
          for target_op, source_op in op_pairs:
              target_op.set_amp_options(source_op.amp_options)
              target_op.struct_name = source_op.struct_name
  def list_vars(self):
      """
      Iterate over all Tensors of this Program, block by block.

      Returns:
          iterable Tensors: The Generator will yield every Tensor in this program.

      Examples:
          .. code-block:: python

              >>> import paddle
              >>> import paddle.static as static

              >>> paddle.enable_static()

              >>> prog = static.default_main_program()
              >>> img = static.data(name='img', shape=[None, 1,28,28], dtype='float32')
              >>> label = static.data(name='label', shape=[None,1], dtype='int64')
              >>> for var in prog.list_vars():
              ...     print(var)

              >>> # var img : LOD_TENSOR.shape(-1, 1, 28, 28).dtype(float32).stop_gradient(True)
              >>> # var label : LOD_TENSOR.shape(-1, 1).dtype(int64).stop_gradient(True)
      """
      for blk in self.blocks:
          # Snapshot the variables of the block before yielding them.
          snapshot = list(blk.vars.values())
          for tensor in snapshot:
              yield tensor
  def all_parameters(self):
      """
      Collect every :ref:`api_guide_parameter_en` of this Program into one
      list, block by block.

      Returns:
          list[ :ref:`api_guide_parameter_en` ]: The list contains all parameters in this program.

      Examples:
          .. code-block:: python

              >>> import paddle
              >>> import paddle.static as static

              >>> paddle.enable_static()

              >>> program = static.default_main_program()
              >>> data = static.data(name='x', shape=[None, 13], dtype='float32')
              >>> hidden = static.nn.fc(x=data, size=10)
              >>> loss = paddle.mean(hidden)
              >>> paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)

              >>> for param in program.all_parameters():
              ...     print(param)

              >>> # Here will print all parameters in current program, in this example,
              >>> # the result is like:
              >>> #
              >>> # persist trainable param fc_0.w_0 : LOD_TENSOR.shape(13, 10).dtype(float32).stop_gradient(False)
              >>> # persist trainable param fc_0.b_0 : LOD_TENSOR.shape(10,).dtype(float32).stop_gradient(False)
              >>> #
              >>> # Here print(param) will print out all the properties of a parameter,
              >>> # including name, type and persistable, you can access to specific
              >>> # property of a parameter, such as param.name, param.type
      """
      return [
          param for blk in self.blocks for param in blk.all_parameters()
      ]
  def state_dict(self, mode="all", scope=None):
      """
      Get parameters and persistable buffers of program as a dict. The key is the name of the parameter or the name of the buffer.
      The value is the tensor of this variable in the given scope.

      .. note::
          This function MUST called after run start_up_program

      Args:
          mode(str, optional): Source of the obtained parameters and buffers.
                  'opt' :  The return value only contains the variable in the optimizer.
                  'param' : The return value only contains the variable in the network, not the variable in the optimizer.
                  'all' : The return value contains the variable in the network and optimizer.
                  Default: 'all'
          scope(Scope, optional) : If scope is None, state_dict will be set to global scope
              obtained through 'paddle.static.global_scope()'. Otherwise, value will be set to scope.
              Default: None

      Returns:
          dict: a dict contains the parameters and persistable buffers.

      Raises:
          TypeError: If ``scope`` is neither None nor a Scope, or ``mode``
              is not a string.
          ValueError: If ``mode`` is not one of 'param', 'opt', 'all', or a
              selected variable cannot be found in the scope.

      Examples:
          .. code-block:: python

              >>> import paddle
              >>> import paddle.static as static

              >>> paddle.enable_static()

              >>> x = static.data(name="x", shape=[10, 10], dtype='float32')
              >>> y = static.nn.fc(x, 10)
              >>> z = static.nn.fc(y, 10)

              >>> place = paddle.CPUPlace()
              >>> exe = static.Executor(place)
              >>> exe.run(static.default_startup_program())
              >>> prog = static.default_main_program()

              >>> path = "./temp/model.pdparams"
              >>> paddle.save(prog.state_dict(), path)
      """
      # 'framework' is a low-level module, and 'executor' can not be
      # imported at the beginning of this file, so it is imported lazily.
      from .executor import global_scope

      if scope is not None and not isinstance(scope, core._Scope):
          raise TypeError(
              f"`scope` should be None or `paddle.static.Scope` type, but received {type(scope)}."
          )

      if scope is None:
          scope = global_scope()

      if not isinstance(mode, str):
          raise TypeError(
              f"Type of `mode` should be string, but received {type(mode)}."
          )

      # Reject an unknown mode up front. Checking it only while filtering
      # would let an invalid mode pass silently for a program that has no
      # variables.
      if mode not in ("param", "opt", "all"):
          raise ValueError(
              f"`mode` string should be 'param', 'opt' or 'all', but received {mode}."
          )

      def is_parameter(var):
          return isinstance(var, Parameter)

      def is_persistable(var):
          # Feed/fetch/reader variables are never part of the state.
          if (
              var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH
              or var.desc.type() == core.VarDesc.VarType.FETCH_LIST
              or var.desc.type() == core.VarDesc.VarType.READER
          ):
              return False
          return var.persistable

      def is_belong_to_optimizer(var):
          # Anything persistable that is neither a Parameter nor a feed
          # variable is treated as optimizer state.
          if not (isinstance(var, Parameter) or var.desc.need_check_feed()):
              return is_persistable(var)
          return False

      def condition(var):
          if mode == "param":
              return is_parameter(var)
          if mode == "opt":
              return is_belong_to_optimizer(var)
          # mode == "all" (validated above)
          return is_parameter(var) or is_belong_to_optimizer(var)

      var_list = filter(condition, self.list_vars())

      state_dict = {}
      for var in var_list:
          var_temp = scope.find_var(var.name)
          if var_temp is None:
              raise ValueError(
                  f"Can not find Variable '{var.name}' in the scope. Make sure it is initialized"
              )
          state_dict[var.name] = var_temp.get_tensor()

      return state_dict
  def set_state_dict(self, state_dict, scope=None):
      """
      Load parameters and persistable buffers from ``state_dict`` into this
      program. Entries that cannot be loaded (unknown name, or a
      ValueError/TypeError raised while setting the value) are skipped with
      a warning.

      .. note::
          This function MUST called after run start_up_program

      Args:
          state_dict(dict): the dict store parameters and persistable buffers.
              The key is the name of the parameter or the name of the buffer.
              The value is the tensor of this variable in the given scope.
          scope(Scope, optional) : If scope is None, state_dict will be set to global scope
              obtained through 'paddle.static.global_scope()'. Otherwise, value will be set to scope.
              Default: None

      Returns:
          None

      Raises:
          TypeError: If ``state_dict`` is not a dict.

      Examples:
          .. code-block:: python

              >>> import paddle
              >>> import paddle.static as static

              >>> paddle.enable_static()

              >>> x = static.data(name="x", shape=[10, 10], dtype='float32')
              >>> y = static.nn.fc(x, 10)
              >>> z = static.nn.fc(y, 10)

              >>> place = paddle.CPUPlace()
              >>> exe = static.Executor(place)
              >>> exe.run(static.default_startup_program())
              >>> prog = static.default_main_program()

              >>> path = "./temp/model.pdparams"
              >>> paddle.save(prog.state_dict(), path)
              >>> state_dict_load = paddle.load(path)
              >>> prog.set_state_dict(state_dict_load)
      """
      if not isinstance(state_dict, dict):
          raise TypeError(
              f"Type of `state_dict` should be dict, but received {type(state_dict)}."
          )

      name_table_key = "StructuredToParameterName@@"
      program_vars = {var.name: var for var in self.list_vars()}
      has_name_table = name_table_key in state_dict

      for name, value in state_dict.items():
          if has_name_table:
              # The table itself is metadata, not a tensor to load.
              if name == name_table_key:
                  continue
              # Translate a structured name to the parameter name.
              name_table = state_dict[name_table_key]
              if name in name_table:
                  name = name_table[name]

          if name not in program_vars:
              warnings.warn(
                  f"Skip loading for '{name}'. Because '{name}' not in the program."
              )
              continue

          try:
              program_vars[name].set_value(value, scope)
          except (ValueError, TypeError) as err:
              warnings.warn(f"Skip loading for '{name}'. " + str(err))
  class Parameter(Variable, metaclass=ParameterMetaClass):
      """
      Parameter is derived from Variable. A parameter is a persistable
      Variable, and will be updated by optimizers after each iteration.
      The training of a neural network is essentially the updating of
      its parameters.

      Relative to a general Variable, a Parameter has several its own
      member variables:

      Args:
          trainable(bool): True if the parameter need to be updated after
              iterations.
          optimize_attr(map): Parameter attributes related with optimizing.
              Currently, it only contains 'learning_rate'.
              Default: {'learning_rate': 1.0}
          regularizer(WeightDecayRegularizer): The Regularizer which will
              be applied on the parameter. Default: None
          do_model_average(bool): True if the model average strategy will
              be applied on this parameter.
          need_clip (bool): Whether the parameter gradient need to be clipped
              in optimizer. Default is True.
      """

      def __init__(
          self,
          block,
          shape,
          dtype,
          type=core.VarDesc.VarType.LOD_TENSOR,
          **kwargs,
      ):
          """
          Create a persistable Variable in ``block`` and attach the
          parameter attributes read from ``kwargs``.

          Raises:
              ValueError: If ``shape`` or ``dtype`` is None, or any
                  dimension of ``shape`` is negative.
          """
          if shape is None:
              raise ValueError("The shape of Parameter should not be None")
          if dtype is None:
              raise ValueError("The dtype of Parameter should not be None")

          for each in shape:
              # Zero-sized dimensions are accepted; only negative ones are
              # rejected, and the message says so.
              if each < 0:
                  raise ValueError(
                      "Each dimension of shape for Parameter must be greater than or equal to 0, but received %s"
                      % list(shape)
                  )

          Variable.__init__(
              self,
              block,
              persistable=True,
              shape=shape,
              dtype=dtype,
              type=type,
              **kwargs,
          )
          self.trainable = kwargs.get("trainable", True)

          # A parameter that is not trainable does not need gradients.
          self.stop_gradient = not self.trainable

          self.optimize_attr = kwargs.get("optimize_attr", {"learning_rate": 1.0})

          self.regularizer = kwargs.get("regularizer", None)

          self.do_model_average = kwargs.get("do_model_average", None)

          self.need_clip = kwargs.get("need_clip", True)

          self.is_distributed = False

          self.is_parameter = True

      def __str__(self):
          return self._to_readable_code()

      def to_string(self, throw_on_error, with_details=False):
          """
          To debug string.

          Args:
              throw_on_error(bool): raise exception when self is not initialized
                  when throw_on_error is True
              with_details(bool): more details about variables and parameters
                  (e.g. trainable, optimize_attr, ...) will be printed when with_details is True

          Returns(str): The debug string.

          Examples:
              .. code-block:: python

                  >>> import paddle

                  >>> paddle.enable_static()
                  >>> prog = paddle.static.default_main_program()
                  >>> rlt = paddle.static.data("fake_data", shape=[-1,1,1], dtype='float32')
                  >>> # Program.to_string takes the same arguments.
                  >>> debug_str = prog.to_string(throw_on_error=True, with_details=False)
                  >>> print(debug_str)
          """
          assert isinstance(throw_on_error, bool) and isinstance(
              with_details, bool
          )
          if with_details:
              res_str = Variable.to_string(self, throw_on_error, True)
              # Append the attributes that only a Parameter has.
              additional_attr = (
                  "trainable",
                  "optimize_attr",
                  "regularizer",
                  "do_model_average",
                  "need_clip",
              )
              for attr_name in additional_attr:
                  res_str += f"{attr_name}: {getattr(self, attr_name)}\n"
          else:
              res_str = Variable.to_string(self, throw_on_error, False)
          return res_str

      __repr__ = __str__
  class EagerParamBase(core.eager.Tensor):
      """
      EagerParamBase is derived from Tensor( Which is the concept in Eager-Dygraph Mode).
      A EagerParamBase is a persistable Tensor, and will be updated by optimizers
      after each iteration.
      The training of a neural network is essentially the updating of
      its EagerParamBase.

      Relative to a general Tensor, a EagerParamBase has several its own
      member variables:

      Args:
          trainable(bool): True if the EagerParamBase need to be updated after
              iterations.
          optimize_attr(map): EagerParamBase attributes related with optimizing.
              Currently, it only contains 'learning_rate'.
              Default: {'learning_rate': 1.0}
          regularizer(WeightDecayRegularizer): The Regularizer which will
              be applied on the EagerParamBase. Default: None
          do_model_average(bool): True if the model average strategy will
              be applied on this EagerParamBase.
          need_clip (bool): Whether the parameter gradient need to be clipped
              in optimizer. Default is True.
      """

      @dygraph_only
      def __init__(self, shape, dtype, **kwargs):
          """
          Create the underlying Tensor and attach the parameter attributes
          read from ``kwargs``.

          Raises:
              ValueError: If ``shape`` or ``dtype`` is None, or any
                  dimension of ``shape`` is negative.
          """
          if shape is None:
              raise ValueError("The shape of Parameter should not be None")
          if dtype is None:
              raise ValueError("The dtype of Parameter should not be None")

          # A shape given as a Tensor is turned into a plain list up front.
          # Keeping it as a numpy array would make the truthiness test that
          # used to guard list(shape) raise for any multi-element shape.
          if isinstance(shape, core.eager.Tensor):
              shape = shape.numpy().tolist()

          for each in shape:
              # Zero-sized dimensions are accepted; only negative ones are
              # rejected, and the message says so.
              if each < 0:
                  raise ValueError(
                      "Each dimension of shape for Parameter must be greater than or equal to 0, but received %s"
                      % list(shape)
                  )

          # dtype is known to be non-None here, so it is always converted.
          dtype = convert_to_proto_type(dtype)

          name = kwargs.get("name", unique_name.generate("_eager_param_base"))

          super().__init__(
              dtype,
              list(shape),
              name,
              core.VarDesc.VarType.LOD_TENSOR,
              True,
          )
          self.retain_grads()

          trainable = kwargs.get("trainable", True)
          self.stop_gradient = not trainable

          self.optimize_attr = kwargs.get("optimize_attr", {"learning_rate": 1.0})

          self.regularizer = kwargs.get("regularizer", None)

          self.do_model_average = kwargs.get("do_model_average", None)

          self.need_clip = kwargs.get("need_clip", True)

          self.is_distributed = kwargs.get("is_distributed", False)
          # hook functions for lazy initialization
          self._init_func = None
          self._init_op_creator = None

      @classmethod
      def from_tensor(cls, tensor, **kwargs):
          """
          Build a parameter that takes its shape, dtype and data from
          ``tensor``. When both ``process_mesh`` and ``placements`` are
          given in ``kwargs``, the data is first wrapped in a Tensor
          created with them and the parameter is named ``tensor.name + ".dist"``.
          """
          # 1. construct EagerParamBase
          param = cls(tensor.shape, tensor.dtype, **kwargs)

          # 2. transform data if needed
          mesh = kwargs.get("process_mesh", None)
          placements = kwargs.get("placements", None)
          src_tensor = tensor

          if mesh is not None and placements is not None:
              src_tensor = core.eager.Tensor(
                  tensor, process_mesh=mesh, placements=placements
              )
              param.name = tensor.name + ".dist"

          # 3. set param data
          param._set_impl(src_tensor)
          return param

      def set_init_func(self, obj):
          """Store ``obj`` as the hook later called by ``initialize``."""
          self._init_func = obj

      @dygraph_only
      def initialize(self):
          """Run the stored init hook once, then drop it."""
          assert (
              self._init_func is not None
          ), "Required self._init_func is not None, but received None."
          self._init_func(self, None)
          # clear function handle to release resource
          self._init_func = None

      @property
      def trainable(self):
          """bool: True when gradients are not stopped for this parameter."""
          return not self.stop_gradient

      @trainable.setter
      def trainable(self, trainable):
          if isinstance(trainable, bool):
              self.stop_gradient = not trainable
          else:
              # One formatted message instead of a two-element args tuple.
              raise ValueError(
                  f"The type of trainable MUST be bool, but the type is {type(trainable)}"
              )

      def _create_init_op(self, block):
          """
          Call init_op_creator function to create initializer operation in block.
          """
          assert (
              self._init_op_creator is not None
          ), "Required self._init_op_creator is not None, but received None."
          self._init_op_creator(self, block)

      def __str__(self):
          """
          Convert a EagerParamBase object to a readable string.

          Returns(str): A readable string.

          Examples:
              .. code-block:: python

                  >>> import paddle
                  >>> linear = paddle.nn.Linear(3, 3)
                  >>> print(linear.weight)
                  >>> # doctest: +SKIP('it will be different')
                  Parameter containing:
                  Tensor(shape=[3, 3], dtype=float32, place=Place(cpu), stop_gradient=False,
                  [[ 0.48948765,  0.05829060, -0.25524026],
                   [-0.70368278,  0.52986908, -0.68742192],
                   [-0.54217887,  0.48439729,  0.34082305]])
          """
          return f"Parameter containing:\n{super().__str__()}"

      def __deepcopy__(self, memo):
          """
          Deep copy parameter, it will always performs Tensor copy.

          Examples:
              .. code-block:: python

                  >>> import paddle
                  >>> import copy
                  >>> linear = paddle.nn.Linear(1, 3)
                  >>> linear_copy = copy.deepcopy(linear)

                  >>> print(linear.weight)
                  >>> # doctest: +SKIP('it will be different')
                  Parameter containing:
                  Tensor(shape=[1, 3], dtype=float32, place=Place(cpu), stop_gradient=False,
                      [[-0.30929261, -0.90929240, -1.07851017]])

                  >>> # doctest: -SKIP
                  >>> print(linear_copy.weight)
                  >>> # doctest: +SKIP('it will be different')
                  Parameter containing:
                  Tensor(shape=[1, 3], dtype=float32, place=Place(cpu), stop_gradient=False,
                      [[-0.30929261, -0.90929240, -1.07851017]])
          """
          state = copy.deepcopy(self.__dict__, memo)
          # The copy gets a new, unique name derived from the original one.
          state["name"] = self.name + unique_name.generate("_deepcopy")
          new_param = EagerParamBase(self.shape, self.dtype, **state)
          memo[id(self)] = new_param
          new_param.copy_(self, True)
          # The hooks are shared with the original, not copied.
          new_param._init_func = self._init_func
          new_param._init_op_creator = self._init_op_creator
          return new_param

      def _copy_to(self, device, blocking):
          """Return a new parameter holding a copy of this one on ``device``."""
          state = copy.deepcopy(self.__dict__)
          new_param = EagerParamBase(self.shape, self.dtype, **state)
          core.eager.tensor_copy(self, new_param, device, blocking)
          return new_param

      __repr__ = __str__
  # Module-level default programs: one main program and one startup
  # program, the latter tagged as a startup program.
  _main_program_ = Program()
  _startup_program_ = Program()
  _startup_program_._is_start_up_program_ = True
  def default_startup_program():
      """
      Get default/global startup program.

      The :code:`paddle.nn` function will append the initialization operators into startup program.
      The :code:`startup_program` will initialize the parameters by the OPs.

      This method returns the startup program that is currently installed
      as the default one. Users can use
      :ref:`api_paddle_base_framework_program_guard` to switch :ref:`api_paddle_base_framework_Program` .

      Returns:
          Program: current default startup program.

      Examples:
          .. code-block:: python

              >>> import paddle

              >>> paddle.enable_static()
              >>> x = paddle.static.data(name="x", shape=[-1, 784], dtype='float32')
              >>> out = paddle.static.nn.fc(name="fc", x=x, size=10, activation="relu")
              >>> print("main program is: {}".format(paddle.static.default_main_program()))
              >>> print("start up program is: {}".format(paddle.static.default_startup_program()))
      """
      active_startup = _startup_program_
      return active_startup
  def default_main_program():
      """
      This API can be used to get ``default main program`` which store the
      descriptions of Ops and tensors.

      For example ``z = paddle.add(x, y)`` will create a new ``add``
      Op and a new ``z`` tensor, and they will be recorded in ``default main program`` .

      The ``default main program`` is the default value for ``Program`` parameter in
      a lot of APIs. For example, the :code:`Executor.run()` will execute the
      :code:`default_main_program` when the program is not specified.

      If you want to switch the ``default main program``, you can use :ref:`api_paddle_base_framework_program_guard` .

      Returns:
          Program: A ``Program`` which holding the descriptions of OPs and tensors in the network.

      Examples:
          .. code-block:: python

              >>> import paddle

              >>> paddle.enable_static()
              >>> # Sample Network:
              >>> x = paddle.static.data(name='x', shape=[100, 100], dtype='float32')
              >>> y = paddle.static.data(name='y', shape=[100, 100], dtype='float32')
              >>> out = paddle.add(x, y)

              >>> # print the number of blocks in the program, 1 in this case
              >>> print(paddle.static.default_main_program().num_blocks)
              1
              >>> # print the default_main_program
              >>> print(paddle.static.default_main_program())
      """
      active_main = _main_program_
      return active_main
  def switch_main_program(program):
      """
      Install ``program`` as the global main program.

      Args:
          program(Program): The new main program

      Returns:
          Program: The previous main program
      """
      global _main_program_
      # Swap in one step: remember the old program, install the new one.
      replaced, _main_program_ = _main_program_, program
      return replaced
  def switch_startup_program(program):
      """
      Install ``program`` as the global startup program.

      Args:
          program(Program): The new startup program

      Returns:
          Program: The previous startup program
      """
      global _startup_program_
      # Swap in one step: remember the old program, install the new one.
      replaced, _startup_program_ = _startup_program_, program
      return replaced
  @signature_safe_contextmanager
  def program_guard(main_program, startup_program=None):
      """
      :api_attr: Static Graph

      Change the global main program and startup program with ``with`` statement.
      Layer functions in the Python ``with`` block will append operators and
      Tensors to the new main programs. The previous programs are restored
      when the ``with`` block exits.

      Args:
          main_program(Program): New main program inside ``with`` statement.
          startup_program(Program, optional): New startup program inside ``with``
              statement. :code:`None` means not changing startup program,
              default_startup_program is still used.
              Default: None.

      Examples:
          .. code-block:: python
              :name: code-example-1

              >>> import paddle

              >>> paddle.enable_static()
              >>> main_program = paddle.static.Program()
              >>> startup_program = paddle.static.Program()
              >>> with paddle.static.program_guard(main_program, startup_program):
              ...     data = paddle.static.data(name='image', shape=[None, 784, 784], dtype='float32')
              ...     hidden = paddle.static.nn.fc(x=data, size=10, activation='relu')

      Notes: The temporary :code:`Program` can be used if the user does not need
      to construct either of startup program or main program.

      Examples:
          .. code-block:: python
              :name: code-example-2

              >>> import paddle

              >>> paddle.enable_static()
              >>> main_program = paddle.static.Program()
              >>> # does not care about startup program. Just pass a temporary value.
              >>> with paddle.static.program_guard(main_program, paddle.static.Program()):
              ...     data = paddle.static.data(name='image', shape=[None, 784, 784], dtype='float32')
      """
      from .data_feeder import check_type

      # Validate both arguments before touching any global state. Checking
      # startup_program only after the main program had been switched left
      # the global main program replaced for good when that check failed.
      check_type(
          main_program, "main_program", Program, "paddle.static.program_guard"
      )
      if startup_program is not None:
          check_type(
              startup_program,
              "startup_program",
              Program,
              "paddle.static.program_guard",
          )

      # From here on the names hold the *previous* programs.
      main_program = switch_main_program(main_program)
      if startup_program is not None:
          # Tag the program __is_start_up as True
          startup_program._is_start_up_program_ = True
          startup_program = switch_startup_program(startup_program)
      try:
          yield
      finally:
          switch_main_program(main_program)
          if startup_program is not None:
              switch_startup_program(startup_program)
  6540. def _get_var(name, program=None):
  6541. """
  6542. Get a variable by name from the global block of a program.
  6543. Args:
  6544. name(str): name of the variable
  6545. program(Program|None): program object.
  6546. If None, default_global_program() will be used.
  6547. Returns:
  6548. Variable
  6549. """
  6550. if program is None:
  6551. program = default_main_program()
  6552. assert isinstance(name, str)
  6553. assert isinstance(program, Program)
  6554. return program.global_block().var(name)
  6555. @signature_safe_contextmanager
  6556. def dygraph_guard_if_declarative():
  6557. from .dygraph import Tracer
  6558. from .dygraph.base import in_to_static_mode
  6559. if in_to_static_mode():
  6560. # Under @paddle.jit.to_static decorator, we switch back dygraph mode temporarily.
  6561. with _dygraph_guard(tracer=Tracer()):
  6562. yield
  6563. else:
  6564. yield
  6565. @signature_safe_contextmanager
  6566. def _dygraph_guard(tracer):
  6567. tmp_tracer = global_var._dygraph_tracer_
  6568. global_var._dygraph_tracer_ = tracer
  6569. try:
  6570. yield
  6571. finally:
  6572. global_var._dygraph_tracer_ = tmp_tracer
  6573. @signature_safe_contextmanager
  6574. def _dygraph_place_guard(place):
  6575. global _global_expected_place_
  6576. tmp_place = _global_expected_place_
  6577. _global_expected_place_ = place
  6578. _set_dygraph_tracer_expected_place(place)
  6579. try:
  6580. yield
  6581. finally:
  6582. _global_expected_place_ = tmp_place
  6583. _set_dygraph_tracer_expected_place(_global_expected_place_)
  6584. def switch_device(device):
  6585. global _current_device
  6586. pre_device = _current_device
  6587. _current_device = device
  6588. return pre_device
  6589. @signature_safe_contextmanager
  6590. def device_guard(device=None):
  6591. """
  6592. Note:
  6593. The API only supports static graph mode.
  6594. A context manager that specifies the device on which the OP will be placed.
  6595. Args:
  6596. device(str|None): Specify the device to use in the context. It should be ``cpu``,
  6597. ``gpu`` or ``gpu:x``, where ``x`` is the index of the GPUs.
  6598. When it is set to 'cpu' or 'gpu', all OPs created in the context will be
  6599. placed on CPUPlace or CUDAPlace. When 'gpu' is set and the program runs on
  6600. single-card, the device index will be the same as the device on which the
  6601. executor runs. Default: None, OPs in this context will be automatically
  6602. assigned devices.
  6603. Examples:
  6604. .. code-block:: python
  6605. >>> # doctest: +REQUIRES(env:GPU)
  6606. >>> import paddle
  6607. >>> paddle.device.set_device('gpu')
  6608. >>> paddle.enable_static()
  6609. >>> support_gpu = paddle.is_compiled_with_cuda()
  6610. >>> place = paddle.CPUPlace()
  6611. >>> if support_gpu:
  6612. ... place = paddle.CUDAPlace(0)
  6613. >>> # if GPU is supported, the three OPs below will be automatically assigned to CUDAPlace(0)
  6614. >>> data1 = paddle.full(shape=[1, 3, 8, 8], fill_value=0.5, dtype='float32')
  6615. >>> data2 = paddle.full(shape=[1, 3, 64], fill_value=0.5, dtype='float32')
  6616. >>> shape = paddle.shape(data2)
  6617. >>> with paddle.static.device_guard("cpu"):
  6618. ... # Ops created here will be placed on CPUPlace
  6619. ... shape = paddle.slice(shape, axes=[0], starts=[0], ends=[4])
  6620. >>> with paddle.static.device_guard('gpu'):
  6621. ... # if GPU is supported, OPs created here will be placed on CUDAPlace(0), otherwise on CPUPlace
  6622. ... out = paddle.reshape(data1, shape=shape)
  6623. >>> exe = paddle.static.Executor(place)
  6624. >>> exe.run(paddle.static.default_startup_program())
  6625. >>> result = exe.run(fetch_list=[out])
  6626. """
  6627. index = None
  6628. if device and ":" in device:
  6629. device, index = device.split(":")
  6630. if device == "cpu":
  6631. raise ValueError("Should not set device id for cpu.")
  6632. if (
  6633. device not in ["cpu", "gpu", "xpu", "", None]
  6634. and device not in core.get_all_custom_device_type()
  6635. ):
  6636. raise ValueError(
  6637. "The Attr(device) should be 'cpu', 'xpu', 'gpu' or custom device, and it can also be empty string or None "
  6638. "when there is no need to specify device. But received %s" % device
  6639. )
  6640. if index:
  6641. device = ":".join([device, index])
  6642. pre_device = switch_device(device)
  6643. try:
  6644. yield
  6645. finally:
  6646. switch_device(pre_device)
  6647. def _switch_cuda_graph_mode(cuda_graph_attr):
  6648. global _current_cuda_graph_mode
  6649. pre_mode = _current_cuda_graph_mode
  6650. _current_cuda_graph_mode = cuda_graph_attr
  6651. return pre_mode
  6652. @signature_safe_contextmanager
  6653. def _cuda_graph_guard(cuda_graph_attr=None):
  6654. """
  6655. Note:
  6656. The API only supports static graph mode.
  6657. A context manager that specifies the cuda_graph_mode which indicating the cuda graph capture under static graph mode.
  6658. Args:
  6659. cuda_graph_attr(str|None): The cuda graph attr with the format of:
  6660. cuda_graph_capture_mode;memory_pool_id;cuda_graph_id
  6661. """
  6662. assert (
  6663. not in_dygraph_mode()
  6664. ), "cuda_graph_guard only works under static graph mode"
  6665. assert (
  6666. core.is_compiled_with_cuda()
  6667. ), "cuda_graph_guard context can be only used when Paddle is compiled with cuda"
  6668. pre_mode = _switch_cuda_graph_mode(cuda_graph_attr)
  6669. try:
  6670. yield
  6671. finally:
  6672. _switch_cuda_graph_mode(pre_mode)
  6673. def _get_paddle_place(place):
  6674. "convert the string to paddle Place"
  6675. if place is None:
  6676. return place
  6677. if isinstance(
  6678. place,
  6679. (
  6680. core.Place,
  6681. core.XPUPlace,
  6682. core.CPUPlace,
  6683. core.CUDAPinnedPlace,
  6684. core.CUDAPlace,
  6685. core.IPUPlace,
  6686. core.CustomPlace,
  6687. ),
  6688. ):
  6689. return place
  6690. if not isinstance(place, str):
  6691. raise ValueError(
  6692. "place only support string which is 'Place' and so on."
  6693. )
  6694. place = place.lower()
  6695. if place == "cpu":
  6696. return core.CPUPlace()
  6697. if place == "device":
  6698. return core.Place()
  6699. # GPU
  6700. available_gpu_place = re.match(r"gpu:\d+", place)
  6701. if place == "gpu_pinned" or place == "gpu" or available_gpu_place:
  6702. if not core.is_compiled_with_cuda():
  6703. raise ValueError(
  6704. f"The device should not be {available_gpu_place.group()}, since PaddlePaddle is "
  6705. "not compiled with CUDA"
  6706. )
  6707. if place == "gpu_pinned":
  6708. return core.CUDAPinnedPlace()
  6709. elif place == "gpu":
  6710. return core.CUDAPlace(0)
  6711. else:
  6712. place_info_list = place.split(":", 1)
  6713. device_id = place_info_list[1]
  6714. device_id = int(device_id)
  6715. return core.CUDAPlace(device_id)
  6716. # XPU
  6717. available_xpu_place = re.match(r"xpu:\d+", place)
  6718. if available_xpu_place:
  6719. if not core.is_compiled_with_xpu():
  6720. raise ValueError(
  6721. f"The device should not be {available_xpu_place.group()}, since PaddlePaddle is "
  6722. "not compiled with XPU"
  6723. )
  6724. place_info_list = place.split(":", 1)
  6725. device_id = place_info_list[1]
  6726. device_id = int(device_id)
  6727. return core.XPUPlace(device_id)
  6728. # IPU
  6729. available_ipu_place = re.match(r"ipu:\d+", place)
  6730. if available_ipu_place:
  6731. if not core.is_compiled_with_ipu():
  6732. raise ValueError(
  6733. f"The device should not be {available_ipu_place.group()}, since PaddlePaddle is "
  6734. "not compiled with IPU"
  6735. )
  6736. place_info_list = place.split(":", 1)
  6737. device_id = place_info_list[1]
  6738. device_id = int(device_id)
  6739. return core.IPUPlace(device_id)
  6740. place_info_list = place.split(":", 1)
  6741. device_type = place_info_list[0]
  6742. if device_type in core.get_all_custom_device_type():
  6743. device_id = place_info_list[1]
  6744. device_id = int(device_id)
  6745. return core.CustomPlace(device_type, device_id)
  6746. raise ValueError(
  6747. f"Paddle supports CPUPlace, CUDAPlace, CUDAPinnedPlace, XPUPlace, IPUPlace and CustomPlace, but received {place}."
  6748. )
  6749. def _get_paddle_place_list(places):
  6750. if not isinstance(places, (list, tuple)):
  6751. raise TypeError("places must to be List or Tuple")
  6752. ret = []
  6753. for p in places:
  6754. p = _get_paddle_place(p)
  6755. ret.append(p)
  6756. return ret
  6757. def dtype_to_str(in_dtype):
  6758. if in_dtype == paddle.float16:
  6759. return "fp16"
  6760. elif in_dtype == paddle.bfloat16:
  6761. return "bf16"
  6762. elif in_dtype == paddle.float32:
  6763. return "fp32"
  6764. elif in_dtype == paddle.float64:
  6765. return "fp64"
  6766. elif in_dtype == core.VarDesc.VarType.COMPLEX64:
  6767. return "complex64"
  6768. elif in_dtype == core.VarDesc.VarType.COMPLEX128:
  6769. return "complex128"
  6770. else:
  6771. raise TypeError(f"got unspport data type for promotion: {in_dtype}.")
  6772. def add_cast_for_type_promotion(op, block, idx, var_name, out_dtype):
  6773. op_device = op.attr("op_device")
  6774. cast_name = var_name.name + ".cast_" + dtype_to_str(out_dtype)
  6775. out_var = block.create_var(
  6776. name=cast_name,
  6777. dtype=out_dtype,
  6778. persistable=False,
  6779. stop_gradient=var_name.stop_gradient,
  6780. )
  6781. op_role = (
  6782. int(core.op_proto_and_checker_maker.OpRole.Forward)
  6783. if not op.has_attr("op_role")
  6784. else op.attr("op_role")
  6785. )
  6786. block._insert_op_without_sync(
  6787. idx,
  6788. type="cast",
  6789. inputs={"X": var_name},
  6790. outputs={"Out": out_var},
  6791. attrs={
  6792. "in_dtype": var_name.dtype,
  6793. "out_dtype": out_var.dtype,
  6794. "op_device": op_device,
  6795. "op_role": op_role,
  6796. },
  6797. )
  6798. op.desc._rename_input(var_name.name, out_var.name)
  6799. def can_skip_promote(op, device):
  6800. # Only GPU/XPU elementwise_add kernel supports the pattern "float + half".
  6801. if device not in ['GPU', 'XPU']:
  6802. return False
  6803. if op.type != "elementwise_add":
  6804. return False
  6805. input_x_dtype = op.block._find_var_recursive(op.input('X')[0]).dtype
  6806. input_y_dtype = op.block._find_var_recursive(op.input('Y')[0]).dtype
  6807. if input_x_dtype == paddle.float32 and (
  6808. input_y_dtype in [paddle.float16, paddle.bfloat16]
  6809. ):
  6810. return True
  6811. return False
  6812. def process_type_promotion(program):
  6813. # Get _current_expected_place place
  6814. device = None
  6815. if core.is_compiled_with_cuda() and isinstance(
  6816. _current_expected_place(), core.CUDAPlace
  6817. ):
  6818. device = 'GPU'
  6819. elif core.is_compiled_with_xpu() and isinstance(
  6820. _current_expected_place(), core.XPUPlace
  6821. ):
  6822. device = 'XPU'
  6823. org_program = program
  6824. if program is None:
  6825. program = default_main_program()
  6826. # not support pir for now
  6827. if not isinstance(program, Program):
  6828. return org_program
  6829. global_block = program.global_block()
  6830. all_params = global_block.all_parameters()
  6831. for block in program.blocks:
  6832. ops = block.ops
  6833. idx = 0
  6834. while idx < len(ops):
  6835. op = ops[idx]
  6836. var_name = None
  6837. all_dtypes = []
  6838. all_input_name_need_cast = []
  6839. need_transed_var_names = SUPPORT_PROMOTION_OPS_AND_INPUTNAME.get(
  6840. op.type, None
  6841. )
  6842. # type promotion only support some dyadic api
  6843. if need_transed_var_names is None or can_skip_promote(op, device):
  6844. idx += 1
  6845. continue
  6846. # get all dtype and input_name
  6847. for input_idx in range(len(op.input_arg_names)):
  6848. if op.input_names[input_idx] in need_transed_var_names:
  6849. input_arg_name = op.input_arg_names[input_idx]
  6850. all_dtypes.append(
  6851. op.block._var_recursive(input_arg_name).dtype
  6852. )
  6853. all_input_name_need_cast.append(input_arg_name)
  6854. # only support promote between float
  6855. if len(all_dtypes) == 2 and core.need_type_promotion(
  6856. op.type, *all_dtypes
  6857. ):
  6858. common_dtype = core.get_promote_dtype(op.type, *all_dtypes)
  6859. for input_name_need_cast in all_input_name_need_cast:
  6860. var_name = op.block._var_recursive(input_name_need_cast)
  6861. if var_name.dtype != common_dtype:
  6862. # add cast op for different dtype
  6863. add_cast_for_type_promotion(
  6864. op,
  6865. block,
  6866. idx,
  6867. var_name,
  6868. common_dtype,
  6869. )
  6870. idx += 1
  6871. idx += 1
  6872. return program