| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154 |
- # Copyright 2020 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import ast
- import collections
- import contextlib
- import copy
- import doctest
- import functools
- import gc
- import importlib
- import inspect
- import logging
- import multiprocessing
- import os
- import re
- import shlex
- import shutil
- import subprocess
- import sys
- import tempfile
- import threading
- import time
- import traceback
- import types
- import unittest
- from collections import UserDict, defaultdict
- from collections.abc import Generator, Iterable, Iterator, Mapping
- from dataclasses import MISSING, fields
- from functools import cache, wraps
- from io import StringIO
- from pathlib import Path
- from typing import Any, Callable, Optional, Union
- from unittest import mock
- from unittest.mock import patch
- import huggingface_hub.utils
- import requests
- import urllib3
- from huggingface_hub import delete_repo
- from packaging import version
- from transformers import Trainer
- from transformers import logging as transformers_logging
- from .integrations import (
- is_clearml_available,
- is_optuna_available,
- is_ray_available,
- is_sigopt_available,
- is_swanlab_available,
- is_tensorboard_available,
- is_trackio_available,
- is_wandb_available,
- )
- from .integrations.deepspeed import is_deepspeed_available
- from .utils import (
- ACCELERATE_MIN_VERSION,
- GGUF_MIN_VERSION,
- TRITON_MIN_VERSION,
- is_accelerate_available,
- is_apex_available,
- is_apollo_torch_available,
- is_aqlm_available,
- is_auto_awq_available,
- is_auto_gptq_available,
- is_auto_round_available,
- is_av_available,
- is_bitsandbytes_available,
- is_bitsandbytes_multi_backend_available,
- is_bs4_available,
- is_compressed_tensors_available,
- is_cv2_available,
- is_cython_available,
- is_decord_available,
- is_detectron2_available,
- is_eetq_available,
- is_essentia_available,
- is_faiss_available,
- is_fbgemm_gpu_available,
- is_flash_attn_2_available,
- is_flash_attn_3_available,
- is_flax_available,
- is_flute_available,
- is_fp_quant_available,
- is_fsdp_available,
- is_ftfy_available,
- is_g2p_en_available,
- is_galore_torch_available,
- is_gguf_available,
- is_gptqmodel_available,
- is_grokadamw_available,
- is_hadamard_available,
- is_hqq_available,
- is_huggingface_hub_greater_or_equal,
- is_ipex_available,
- is_jinja_available,
- is_jumanpp_available,
- is_keras_nlp_available,
- is_kernels_available,
- is_levenshtein_available,
- is_librosa_available,
- is_liger_kernel_available,
- is_lomo_available,
- is_mistral_common_available,
- is_natten_available,
- is_nltk_available,
- is_onnx_available,
- is_openai_available,
- is_optimum_available,
- is_optimum_quanto_available,
- is_pandas_available,
- is_peft_available,
- is_phonemizer_available,
- is_pretty_midi_available,
- is_psutil_available,
- is_pyctcdecode_available,
- is_pytesseract_available,
- is_pytest_available,
- is_pytorch_quantization_available,
- is_quark_available,
- is_qutlass_available,
- is_rjieba_available,
- is_sacremoses_available,
- is_safetensors_available,
- is_schedulefree_available,
- is_scipy_available,
- is_sentencepiece_available,
- is_seqio_available,
- is_spacy_available,
- is_speech_available,
- is_spqr_available,
- is_sudachi_available,
- is_sudachi_projection_available,
- is_tf_available,
- is_tiktoken_available,
- is_timm_available,
- is_tokenizers_available,
- is_torch_available,
- is_torch_bf16_available_on_device,
- is_torch_bf16_gpu_available,
- is_torch_fp16_available_on_device,
- is_torch_greater_or_equal,
- is_torch_hpu_available,
- is_torch_mlu_available,
- is_torch_neuroncore_available,
- is_torch_npu_available,
- is_torch_optimi_available,
- is_torch_tensorrt_fx_available,
- is_torch_tf32_available,
- is_torch_xla_available,
- is_torch_xpu_available,
- is_torchao_available,
- is_torchaudio_available,
- is_torchcodec_available,
- is_torchdynamo_available,
- is_torchvision_available,
- is_triton_available,
- is_vision_available,
- is_vptq_available,
- strtobool,
- )
- if is_accelerate_available():
- from accelerate.state import AcceleratorState, PartialState
- from accelerate.utils.imports import is_fp8_available
- if is_pytest_available():
- from _pytest.doctest import (
- Module,
- _get_checker,
- _get_continue_on_failure,
- _get_runner,
- _is_mocked,
- _patch_unwrap_mock_aware,
- get_optionflags,
- )
- from _pytest.outcomes import skip
- from _pytest.pathlib import import_path
- from pytest import DoctestItem
- else:
- Module = object
- DoctestItem = object
# Hub repository ids used by tests (presumably tiny dummy checkpoints — confirm against the callers).
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown"
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
# Used to test Auto{Config, Model, Tokenizer} model_type detection.
# NOTE(review): no constant follows the comment above; it presumably describes the dummy identifiers — confirm.
# Used to test the hub
USER = "__DUMMY_TRANSFORMERS_USER__"
ENDPOINT_STAGING = "https://hub-ci.huggingface.co"
# Not critical, only usable on the sandboxed CI instance.
TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL"
# Used in CausalLMModelTester (and related classes/methods) to infer the common model classes from the base model class
# Keys are attribute names, values are class-name suffixes (how they are combined is not visible here).
_COMMON_MODEL_NAMES_MAP = {
    "config_class": "Config",
    "causal_lm_class": "ForCausalLM",
    "question_answering_class": "ForQuestionAnswering",
    "sequence_classification_class": "ForSequenceClassification",
    "token_classification_class": "ForTokenClassification",
}
# Platform flags: which accelerator toolchain the installed torch build reports.
# Without torch, every flag is False.
if not is_torch_available():
    IS_ROCM_SYSTEM = False
    IS_CUDA_SYSTEM = False
    IS_XPU_SYSTEM = False
else:
    import torch

    IS_ROCM_SYSTEM = torch.version.hip is not None
    IS_CUDA_SYSTEM = torch.version.cuda is not None
    # `torch.version.xpu` may be absent on some builds, hence the getattr fallback.
    IS_XPU_SYSTEM = getattr(torch.version, "xpu", None) is not None

logger = transformers_logging.get_logger(__name__)
def parse_flag_from_env(key, default=False):
    """
    Read the environment variable `key` and interpret it as a boolean flag.

    Args:
        key: Name of the environment variable to look up.
        default: Value returned unchanged when `key` is not set.

    Returns:
        `default` when the variable is unset, otherwise the result of `strtobool` applied to its value.

    Raises:
        ValueError: If the variable is set but `strtobool` rejects its value. The original error is chained as the
            cause.
    """
    # Environment values are always strings, so `None` can only mean "not set".
    raw_value = os.environ.get(key)
    if raw_value is None:
        return default

    try:
        return strtobool(raw_value)
    except ValueError as err:
        # More values are supported, but let's keep the message simple.
        raise ValueError(f"If set, {key} must be yes or no.") from err
def parse_int_from_env(key, default=None):
    """
    Read the environment variable `key` and convert it to an integer.

    Args:
        key: Name of the environment variable to look up.
        default: Value returned unchanged when `key` is not set.

    Returns:
        `default` when the variable is unset, otherwise `int(value)`.

    Raises:
        ValueError: If the variable is set but cannot be parsed by `int`. The original error is chained as the cause.
    """
    # Environment values are always strings, so `None` can only mean "not set".
    raw_value = os.environ.get(key)
    if raw_value is None:
        return default

    try:
        return int(raw_value)
    except ValueError as err:
        raise ValueError(f"If set, {key} must be a int.") from err
# Test-selection switches, each read once from the environment when this module is imported.
# An unset variable falls back to the `default` given here.
# Consulted by the `slow` decorator.
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_flaky_tests = parse_flag_from_env("RUN_FLAKY", default=True)
# Consulted by the `custom_tokenizers` decorator.
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
# Consulted by the `is_staging_test` decorator.
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
# Consulted by the `is_pipeline_test` decorator.
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
# Consulted by the `is_agent_test` decorator.
_run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False)
def is_staging_test(test_case):
    """
    Decorator marking a test as a staging test.

    Staging tests target the staging environment of huggingface.co rather than the real model hub. They are skipped
    unless the staging flag (`HUGGINGFACE_CO_STAGING`) is enabled. When it is enabled and pytest is importable, the
    test additionally receives the `is_staging_test` pytest mark.
    """
    if not _run_staging:
        return unittest.skip(reason="test is staging test")(test_case)

    try:
        import pytest  # We don't need a hard dependency on pytest in the main library
    except ImportError:
        # No pytest: hand the test back unmarked.
        return test_case

    return pytest.mark.is_staging_test()(test_case)
def is_pipeline_test(test_case):
    """
    Decorator marking a test as a pipeline test.

    Pipeline tests are skipped when RUN_PIPELINE_TESTS is set to a falsy value. Otherwise, when pytest is importable,
    the test additionally receives the `is_pipeline_test` pytest mark.
    """
    if not _run_pipeline_tests:
        return unittest.skip(reason="test is pipeline test")(test_case)

    try:
        import pytest  # We don't need a hard dependency on pytest in the main library
    except ImportError:
        # No pytest: hand the test back unmarked.
        return test_case

    return pytest.mark.is_pipeline_test()(test_case)
def is_agent_test(test_case):
    """
    Decorator marking a test as an agent test.

    Agent tests are skipped unless RUN_AGENT_TESTS is set to a truthy value. When they are enabled and pytest is
    importable, the test additionally receives the `is_agent_test` pytest mark.
    """
    if not _run_agent_tests:
        return unittest.skip(reason="test is an agent test")(test_case)

    try:
        import pytest  # We don't need a hard dependency on pytest in the main library
    except ImportError:
        # No pytest: hand the test back unmarked.
        return test_case

    return pytest.mark.is_agent_test()(test_case)
def slow(test_case):
    """
    Decorator marking a test as slow.

    Slow tests only run when the RUN_SLOW environment variable is set to a truthy value; by default they are skipped.
    """
    skip_unless_slow_enabled = unittest.skipUnless(_run_slow_tests, "test is slow")
    return skip_unless_slow_enabled(test_case)
def tooslow(test_case):
    """
    Decorator marking a test as too slow.

    Such tests are always skipped while they are being fixed. Since the CI never runs them, no test should keep the
    "tooslow" tag permanently.
    """
    always_skip = unittest.skip(reason="test is too slow")
    return always_skip(test_case)
def skip_if_not_implemented(test_func):
    """
    Decorator that turns a `NotImplementedError` raised by `test_func` into a `unittest.SkipTest`.

    Any other outcome (return value or exception) of `test_func` passes through unchanged.
    """

    @wraps(test_func)
    def run_or_skip(*args, **kwargs):
        try:
            outcome = test_func(*args, **kwargs)
        except NotImplementedError as e:
            raise unittest.SkipTest(f"Test skipped due to NotImplementedError: {e}")
        return outcome

    return run_or_skip
def apply_skip_if_not_implemented(cls):
    """
    Class decorator wrapping every callable attribute whose name starts with `test_` in `skip_if_not_implemented`.

    The attributes are looked up through `dir(cls)`, replaced on `cls` in place, and `cls` itself is returned.
    """
    for name in dir(cls):
        if not name.startswith("test_"):
            continue
        member = getattr(cls, name)
        if callable(member):
            setattr(cls, name, skip_if_not_implemented(member))
    return cls
def custom_tokenizers(test_case):
    """
    Decorator marking a test for a custom tokenizer.

    Custom tokenizers need additional dependencies, so these tests are skipped by default. Set the
    RUN_CUSTOM_TOKENIZERS environment variable to a truthy value to run them.
    """
    skip_unless_enabled = unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")
    return skip_unless_enabled(test_case)
def require_bs4(test_case):
    """
    Decorator for tests that depend on BeautifulSoup4; the test is skipped when BeautifulSoup4 isn't installed.
    """
    skip_unless_bs4 = unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")
    return skip_unless_bs4(test_case)
def require_galore_torch(test_case):
    """
    Decorator for tests that depend on GaLore; the test is skipped when GaLore isn't installed.
    https://github.com/jiaweizzhao/GaLore
    """
    skip_unless_galore = unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")
    return skip_unless_galore(test_case)
def require_apollo_torch(test_case):
    """
    Decorator for tests that depend on APOLLO; the test is skipped when APOLLO isn't installed.
    https://github.com/zhuhanqing/APOLLO
    """
    skip_unless_apollo = unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")
    return skip_unless_apollo(test_case)
def require_torch_optimi(test_case):
    """
    Decorator for tests that depend on torch-optimi; the test is skipped when torch-optimi isn't installed.
    https://github.com/jxnl/torch-optimi
    """
    skip_unless_optimi = unittest.skipUnless(is_torch_optimi_available(), "test requires torch-optimi")
    return skip_unless_optimi(test_case)
def require_lomo(test_case):
    """
    Decorator for tests that depend on LOMO; the test is skipped when LOMO-optim isn't installed.
    https://github.com/OpenLMLab/LOMO
    """
    skip_unless_lomo = unittest.skipUnless(is_lomo_available(), "test requires LOMO")
    return skip_unless_lomo(test_case)
def require_grokadamw(test_case):
    """
    Decorator for tests that depend on GrokAdamW; the test is skipped when GrokAdamW isn't installed.
    """
    skip_unless_grokadamw = unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")
    return skip_unless_grokadamw(test_case)
def require_schedulefree(test_case):
    """
    Decorator for tests that depend on schedulefree; the test is skipped when schedulefree isn't installed.
    https://github.com/facebookresearch/schedule_free
    """
    skip_unless_schedulefree = unittest.skipUnless(is_schedulefree_available(), "test requires schedulefree")
    return skip_unless_schedulefree(test_case)
def require_cv2(test_case):
    """
    Decorator for tests that depend on OpenCV.

    The test is skipped when OpenCV isn't installed.
    """
    skip_unless_cv2 = unittest.skipUnless(is_cv2_available(), "test requires OpenCV")
    return skip_unless_cv2(test_case)
def require_levenshtein(test_case):
    """
    Decorator for tests that depend on Levenshtein.

    The test is skipped when Levenshtein isn't installed.
    """
    skip_unless_levenshtein = unittest.skipUnless(is_levenshtein_available(), "test requires Levenshtein")
    return skip_unless_levenshtein(test_case)
def require_nltk(test_case):
    """
    Decorator for tests that depend on NLTK.

    The test is skipped when NLTK isn't installed.
    """
    skip_unless_nltk = unittest.skipUnless(is_nltk_available(), "test requires NLTK")
    return skip_unless_nltk(test_case)
def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
    """
    Decorator for tests that depend on accelerate.

    The test is skipped unless `is_accelerate_available(min_version)` reports that accelerate is usable.
    """
    reason = f"test requires accelerate version >= {min_version}"
    skip_unless_accelerate = unittest.skipUnless(is_accelerate_available(min_version), reason)
    return skip_unless_accelerate(test_case)
def require_triton(min_version: str = TRITON_MIN_VERSION):
    """
    Decorator factory for tests that depend on triton.

    The returned decorator skips the test unless `is_triton_available(min_version)` reports that triton is usable.
    Because this is a factory, it must be applied with parentheses: `@require_triton()`.
    """

    def apply_to(test_case):
        reason = f"test requires triton version >= {min_version}"
        return unittest.skipUnless(is_triton_available(min_version), reason)(test_case)

    return apply_to
def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION):
    """
    Decorator for tests that depend on gguf.

    The test is skipped unless `is_gguf_available(min_version)` reports that gguf is usable.
    """
    reason = f"test requires gguf version >= {min_version}"
    skip_unless_gguf = unittest.skipUnless(is_gguf_available(min_version), reason)
    return skip_unless_gguf(test_case)
def require_fsdp(test_case, min_version: str = "1.12.0"):
    """
    Decorator for tests that depend on fsdp.

    The test is skipped unless `is_fsdp_available(min_version)` reports that fsdp is usable.
    """
    reason = f"test requires torch version >= {min_version}"
    skip_unless_fsdp = unittest.skipUnless(is_fsdp_available(min_version), reason)
    return skip_unless_fsdp(test_case)
def require_g2p_en(test_case):
    """
    Decorator for tests that depend on g2p_en; the test is skipped when g2p_en isn't installed.
    """
    skip_unless_g2p_en = unittest.skipUnless(is_g2p_en_available(), "test requires g2p_en")
    return skip_unless_g2p_en(test_case)
def require_safetensors(test_case):
    """
    Decorator for tests that depend on safetensors; the test is skipped when safetensors isn't installed.
    """
    skip_unless_safetensors = unittest.skipUnless(is_safetensors_available(), "test requires safetensors")
    return skip_unless_safetensors(test_case)
def require_rjieba(test_case):
    """
    Decorator for tests that depend on rjieba; the test is skipped when rjieba isn't installed.
    """
    skip_unless_rjieba = unittest.skipUnless(is_rjieba_available(), "test requires rjieba")
    return skip_unless_rjieba(test_case)
def require_jinja(test_case):
    """
    Decorator for tests that depend on jinja; the test is skipped when jinja isn't installed.
    """
    skip_unless_jinja = unittest.skipUnless(is_jinja_available(), "test requires jinja")
    return skip_unless_jinja(test_case)
def require_onnx(test_case):
    """
    Decorator for tests that depend on ONNX; the test is skipped unless `is_onnx_available()` is truthy.
    """
    skip_unless_onnx = unittest.skipUnless(is_onnx_available(), "test requires ONNX")
    return skip_unless_onnx(test_case)
def require_timm(test_case):
    """
    Decorator for tests that depend on Timm.

    The test is skipped when Timm isn't installed.
    """
    skip_unless_timm = unittest.skipUnless(is_timm_available(), "test requires Timm")
    return skip_unless_timm(test_case)
def require_natten(test_case):
    """
    Decorator for tests that depend on NATTEN.

    The test is skipped when NATTEN isn't installed.
    """
    skip_unless_natten = unittest.skipUnless(is_natten_available(), "test requires natten")
    return skip_unless_natten(test_case)
def require_torch(test_case):
    """
    Decorator for tests that depend on PyTorch.

    The test is skipped when PyTorch isn't installed.
    """
    skip_unless_torch = unittest.skipUnless(is_torch_available(), "test requires PyTorch")
    return skip_unless_torch(test_case)
def require_torch_greater_or_equal(version: str):
    """
    Decorator factory for tests that need PyTorch version >= `version`.

    The returned decorator skips the test when the PyTorch version is lower than `version`.
    """

    def apply_to(test_case):
        reason = f"test requires PyTorch version >= {version}"
        return unittest.skipUnless(is_torch_greater_or_equal(version), reason)(test_case)

    return apply_to
def require_huggingface_hub_greater_or_equal(version: str):
    """
    Decorator factory for tests that need huggingface_hub version >= `version`.

    The returned decorator skips the test when the huggingface_hub version is lower than `version`.
    """

    def apply_to(test_case):
        reason = f"test requires huggingface_hub version >= {version}"
        return unittest.skipUnless(is_huggingface_hub_greater_or_equal(version), reason)(test_case)

    return apply_to
def require_flash_attn(test_case):
    """
    Decorator marking a test that requires Flash Attention.

    The test runs when either `is_flash_attn_2_available()` is true, or the kernels library is available and
    `get_kernel("kernels-community/flash-attn")` succeeds. Otherwise the test is skipped.
    """
    flash_attn_available = is_flash_attn_2_available()
    kernels_available = is_kernels_available()
    if kernels_available:
        try:
            from kernels import get_kernel

            get_kernel("kernels-community/flash-attn")
        except Exception:
            # Deliberate best-effort probe: any failure to import or fetch the kernel means it cannot be used.
            kernels_available = False

    # Logical `or` (not bitwise `|`) so the condition is well-defined for any truthy/falsy helper result.
    return unittest.skipUnless(kernels_available or flash_attn_available, "test requires Flash Attention")(test_case)
def require_kernels(test_case):
    """
    Skip `test_case` unless `is_kernels_available()` reports that the kernels library is installed.
    """
    skip_decorator = unittest.skipUnless(is_kernels_available(), "test requires the kernels library")
    return skip_decorator(test_case)
def require_flash_attn_3(test_case):
    """
    Skip `test_case` unless `is_flash_attn_3_available()` reports that Flash Attention 3 is installed.
    """
    skip_decorator = unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")
    return skip_decorator(test_case)
def require_read_token(test_case):
    """
    Decorator that makes the token from the `HF_HUB_READ_TOKEN` environment variable available to a test.

    Applied to a class, every plain-function attribute that is not already wrapped is replaced by its wrapped
    version and the class itself is returned. Applied to a function, a wrapper is returned (or the function itself
    if it was already wrapped). The environment variable is read when the decorator is applied, not when the test
    runs.
    """
    token = os.getenv("HF_HUB_READ_TOKEN")

    if isinstance(test_case, type):
        for member_name in dir(test_case):
            member = getattr(test_case, member_name)
            if not isinstance(member, types.FunctionType):
                continue
            if getattr(member, "__require_read_token__", False):
                # Already wrapped: do not wrap twice.
                continue
            setattr(test_case, member_name, require_read_token(member))
        return test_case

    if getattr(test_case, "__require_read_token__", False):
        return test_case

    @functools.wraps(test_case)
    def wrapper(*args, **kwargs):
        if token is not None:
            with patch("huggingface_hub.utils._headers.get_token", return_value=token):
                return test_case(*args, **kwargs)

        # No explicit token: allow running locally with the default token env variable.
        # A function whose source mentions `staticmethod` but that is called as `self.xxx` receives the
        # TestCase instance as an extra first argument; drop it before delegating.
        if "staticmethod" in inspect.getsource(test_case).strip():
            if args and isinstance(args[0], unittest.TestCase):
                return test_case(*args[1:], **kwargs)
        return test_case(*args, **kwargs)

    # Marker used above to detect functions that were already wrapped.
    wrapper.__require_read_token__ = True
    return wrapper
def require_peft(test_case):
    """
    Skip `test_case` unless `is_peft_available()` reports that PEFT is installed.
    """
    skip_decorator = unittest.skipUnless(is_peft_available(), "test requires PEFT")
    return skip_decorator(test_case)
def require_torchvision(test_case):
    """
    Skip `test_case` unless `is_torchvision_available()` reports that Torchvision is installed.
    """
    skip_decorator = unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")
    return skip_decorator(test_case)
def require_torchcodec(test_case):
    """
    Skip `test_case` unless `is_torchcodec_available()` reports that Torchcodec is installed.
    """
    skip_decorator = unittest.skipUnless(is_torchcodec_available(), "test requires Torchcodec")
    return skip_decorator(test_case)
def require_torch_or_tf(test_case):
    """
    Skip `test_case` when neither PyTorch nor TensorFlow is available.

    TensorFlow is only probed when `is_torch_available()` is false.
    """
    any_framework = is_torch_available() or is_tf_available()
    skip_decorator = unittest.skipUnless(any_framework, "test requires PyTorch or TensorFlow")
    return skip_decorator(test_case)
def require_intel_extension_for_pytorch(test_case):
    """
    Skip `test_case` unless `is_ipex_available()` is true.

    Per the skip reason, that means Intel Extension for PyTorch is installed and matches the current PyTorch
    version.
    """
    skip_reason = (
        "test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see"
        " https://github.com/intel/intel-extension-for-pytorch"
    )
    skip_decorator = unittest.skipUnless(is_ipex_available(), skip_reason)
    return skip_decorator(test_case)
def require_torchaudio(test_case):
    """
    Skip `test_case` unless `is_torchaudio_available()` reports that torchaudio is installed.
    """
    skip_decorator = unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")
    return skip_decorator(test_case)
def require_sentencepiece(test_case):
    """
    Skip `test_case` unless `is_sentencepiece_available()` reports that SentencePiece is installed.
    """
    skip_decorator = unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")
    return skip_decorator(test_case)
def require_sacremoses(test_case):
    """
    Skip `test_case` unless `is_sacremoses_available()` reports that Sacremoses is installed.
    """
    skip_decorator = unittest.skipUnless(is_sacremoses_available(), "test requires Sacremoses")
    return skip_decorator(test_case)
def require_seqio(test_case):
    """
    Skip `test_case` unless `is_seqio_available()` reports that Seqio is installed.
    """
    skip_decorator = unittest.skipUnless(is_seqio_available(), "test requires Seqio")
    return skip_decorator(test_case)
def require_scipy(test_case):
    """
    Skip `test_case` unless `is_scipy_available()` reports that Scipy is installed.
    """
    skip_decorator = unittest.skipUnless(is_scipy_available(), "test requires Scipy")
    return skip_decorator(test_case)
def require_tokenizers(test_case):
    """
    Skip `test_case` unless `is_tokenizers_available()` reports that 🤗 Tokenizers is installed.
    """
    skip_decorator = unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")
    return skip_decorator(test_case)
def require_keras_nlp(test_case):
    """
    Skip `test_case` unless `is_keras_nlp_available()` reports that keras_nlp is installed.
    """
    skip_decorator = unittest.skipUnless(is_keras_nlp_available(), "test requires keras_nlp")
    return skip_decorator(test_case)
def require_pandas(test_case):
    """
    Skip `test_case` unless `is_pandas_available()` reports that pandas is installed.
    """
    skip_decorator = unittest.skipUnless(is_pandas_available(), "test requires pandas")
    return skip_decorator(test_case)
def require_pytesseract(test_case):
    """
    Skip `test_case` unless `is_pytesseract_available()` reports that PyTesseract is installed.
    """
    skip_decorator = unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")
    return skip_decorator(test_case)
def require_pytorch_quantization(test_case):
    """
    Skip `test_case` unless `is_pytorch_quantization_available()` reports that the PyTorch Quantization Toolkit
    is installed.
    """
    toolkit_present = is_pytorch_quantization_available()
    skip_decorator = unittest.skipUnless(toolkit_present, "test requires PyTorch Quantization Toolkit")
    return skip_decorator(test_case)
def require_vision(test_case):
    """
    Skip `test_case` unless `is_vision_available()` reports that the vision dependencies are installed.
    """
    skip_decorator = unittest.skipUnless(is_vision_available(), "test requires vision")
    return skip_decorator(test_case)
def require_ftfy(test_case):
    """
    Skip `test_case` unless `is_ftfy_available()` reports that ftfy is installed.
    """
    skip_decorator = unittest.skipUnless(is_ftfy_available(), "test requires ftfy")
    return skip_decorator(test_case)
def require_spacy(test_case):
    """
    Skip `test_case` unless `is_spacy_available()` reports that SpaCy is installed.
    """
    skip_decorator = unittest.skipUnless(is_spacy_available(), "test requires spacy")
    return skip_decorator(test_case)
def require_torch_multi_gpu(test_case):
    """
    Skip `test_case` unless PyTorch is available and `torch.cuda.device_count()` is greater than 1.

    To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
    """
    if not is_torch_available():
        return unittest.skip(reason="test requires PyTorch")(test_case)

    import torch

    cuda_device_total = torch.cuda.device_count()
    skip_decorator = unittest.skipUnless(cuda_device_total > 1, "test requires multiple CUDA GPUs")
    return skip_decorator(test_case)
def require_torch_multi_accelerator(test_case):
    """
    Skip `test_case` unless PyTorch is available and `backend_device_count(torch_device)` is greater than 1.

    To run *only* the multi_accelerator tests, assuming all test names contain multi_accelerator:
    $ pytest -sv ./tests -k "multi_accelerator"
    """
    if not is_torch_available():
        return unittest.skip(reason="test requires PyTorch")(test_case)

    accelerator_total = backend_device_count(torch_device)
    skip_decorator = unittest.skipUnless(accelerator_total > 1, "test requires multiple accelerators")
    return skip_decorator(test_case)
def require_torch_non_multi_gpu(test_case):
    """
    Skip `test_case` unless PyTorch is available and `torch.cuda.device_count()` is below 2 (0 or 1 GPU).
    """
    if not is_torch_available():
        return unittest.skip(reason="test requires PyTorch")(test_case)

    import torch

    cuda_device_total = torch.cuda.device_count()
    skip_decorator = unittest.skipUnless(cuda_device_total < 2, "test requires 0 or 1 GPU")
    return skip_decorator(test_case)
def require_torch_non_multi_accelerator(test_case):
    """
    Skip `test_case` unless PyTorch is available and `backend_device_count(torch_device)` is below 2
    (0 or 1 accelerator).
    """
    if not is_torch_available():
        return unittest.skip(reason="test requires PyTorch")(test_case)

    accelerator_total = backend_device_count(torch_device)
    skip_decorator = unittest.skipUnless(accelerator_total < 2, "test requires 0 or 1 accelerator")
    return skip_decorator(test_case)
def require_torch_up_to_2_gpus(test_case):
    """
    Skip `test_case` unless PyTorch is available and `torch.cuda.device_count()` is below 3 (0, 1 or 2 GPUs).
    """
    if not is_torch_available():
        return unittest.skip(reason="test requires PyTorch")(test_case)

    import torch

    cuda_device_total = torch.cuda.device_count()
    skip_decorator = unittest.skipUnless(cuda_device_total < 3, "test requires 0 or 1 or 2 GPUs")
    return skip_decorator(test_case)
def require_torch_up_to_2_accelerators(test_case):
    """
    Skip `test_case` unless PyTorch is available and `backend_device_count(torch_device)` is below 3
    (0, 1 or 2 accelerators).
    """
    if not is_torch_available():
        return unittest.skip(reason="test requires PyTorch")(test_case)

    accelerator_total = backend_device_count(torch_device)
    skip_decorator = unittest.skipUnless(accelerator_total < 3, "test requires 0 or 1 or 2 accelerators")
    return skip_decorator(test_case)
def require_torch_xla(test_case):
    """
    Skip `test_case` unless `is_torch_xla_available()` reports that TorchXLA is available.
    """
    skip_decorator = unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")
    return skip_decorator(test_case)
def require_torch_neuroncore(test_case):
    """
    Skip `test_case` unless `is_torch_neuroncore_available(check_device=False)` is true.
    """
    neuroncore_present = is_torch_neuroncore_available(check_device=False)
    skip_decorator = unittest.skipUnless(neuroncore_present, "test requires PyTorch NeuronCore")
    return skip_decorator(test_case)
def require_torch_npu(test_case):
    """
    Skip `test_case` unless `is_torch_npu_available()` reports that an NPU backend is available.
    """
    skip_decorator = unittest.skipUnless(is_torch_npu_available(), "test requires PyTorch NPU")
    return skip_decorator(test_case)
def require_torch_multi_npu(test_case):
    """
    Skip `test_case` unless an NPU backend is available and `torch.npu.device_count()` is greater than 1.

    To run *only* the multi_npu tests, assuming all test names contain multi_npu: $ pytest -sv ./tests -k "multi_npu"
    """
    if not is_torch_npu_available():
        return unittest.skip(reason="test requires PyTorch NPU")(test_case)

    npu_total = torch.npu.device_count()
    skip_decorator = unittest.skipUnless(npu_total > 1, "test requires multiple NPUs")
    return skip_decorator(test_case)
def require_non_hpu(test_case):
    """
    Skip `test_case` when the module-level `torch_device` is `"hpu"`.
    """
    running_off_hpu = torch_device != "hpu"
    return unittest.skipUnless(running_off_hpu, "test requires a non-HPU")(test_case)
def require_torch_xpu(test_case):
    """
    Skip `test_case` unless `is_torch_xpu_available()` reports that the XPU backend is available.

    The XPU backend might be available either via stock PyTorch (>=2.4) or via Intel Extension for PyTorch. In the
    latter case, if IPEX is installed, its version must match the current PyTorch version.
    """
    skip_decorator = unittest.skipUnless(is_torch_xpu_available(), "test requires XPU device")
    return skip_decorator(test_case)
def require_non_xpu(test_case):
    """
    Skip `test_case` when the module-level `torch_device` is `"xpu"`.
    """
    running_off_xpu = torch_device != "xpu"
    return unittest.skipUnless(running_off_xpu, "test requires a non-XPU")(test_case)
def require_torch_multi_xpu(test_case):
    """
    Skip `test_case` unless the XPU backend is available and `torch.xpu.device_count()` is greater than 1.

    To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
    """
    if not is_torch_xpu_available():
        return unittest.skip(reason="test requires PyTorch XPU")(test_case)

    xpu_total = torch.xpu.device_count()
    skip_decorator = unittest.skipUnless(xpu_total > 1, "test requires multiple XPUs")
    return skip_decorator(test_case)
def require_torch_multi_hpu(test_case):
    """
    Skip `test_case` unless the HPU backend is available and `torch.hpu.device_count()` is greater than 1.

    To run *only* the multi_hpu tests, assuming all test names contain multi_hpu: $ pytest -sv ./tests -k "multi_hpu"
    """
    if not is_torch_hpu_available():
        return unittest.skip(reason="test requires PyTorch HPU")(test_case)

    hpu_total = torch.hpu.device_count()
    skip_decorator = unittest.skipUnless(hpu_total > 1, "test requires multiple HPUs")
    return skip_decorator(test_case)
# Module-level selection of the devices used by the test helpers in this file.
# `torch_device` ends up as: the value of `TRANSFORMERS_TEST_DEVICE` if that variable is set (after validation),
# otherwise the first available backend among cuda, npu, mlu, hpu, xpu, otherwise "cpu"; it is `None` when
# PyTorch is not available. `jax_device` is `jax.default_backend()` when Flax is available, else `None`.
if is_torch_available():
    # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
    import torch

    # Optional extra backend module: it is imported here and an import failure is reported with the variable name.
    if "TRANSFORMERS_TEST_BACKEND" in os.environ:
        backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
        try:
            _ = importlib.import_module(backend)
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
                f" traceback):\n{e}"
            ) from e

    if "TRANSFORMERS_TEST_DEVICE" in os.environ:
        torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
        # An explicitly requested device must be usable: fail at import time rather than inside individual tests.
        if torch_device == "cuda" and not torch.cuda.is_available():
            raise ValueError(
                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but CUDA is unavailable. Please double-check your testing environment."
            )
        if torch_device == "xpu" and not is_torch_xpu_available():
            raise ValueError(
                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but XPU is unavailable. Please double-check your testing environment."
            )
        if torch_device == "npu" and not is_torch_npu_available():
            raise ValueError(
                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment."
            )
        if torch_device == "mlu" and not is_torch_mlu_available():
            raise ValueError(
                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but MLU is unavailable. Please double-check your testing environment."
            )
        if torch_device == "hpu" and not is_torch_hpu_available():
            raise ValueError(
                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but HPU is unavailable. Please double-check your testing environment."
            )

        try:
            # try creating device to see if provided device is valid
            _ = torch.device(torch_device)
        except RuntimeError as e:
            raise RuntimeError(
                f"Unknown testing device specified by environment variable `TRANSFORMERS_TEST_DEVICE`: {torch_device}"
            ) from e
    # No explicit device requested: pick the first available backend in this fixed order.
    elif torch.cuda.is_available():
        torch_device = "cuda"
    elif is_torch_npu_available():
        torch_device = "npu"
    elif is_torch_mlu_available():
        torch_device = "mlu"
    elif is_torch_hpu_available():
        torch_device = "hpu"
    elif is_torch_xpu_available():
        torch_device = "xpu"
    else:
        torch_device = "cpu"
else:
    # Without PyTorch there is no torch device at all.
    torch_device = None

if is_tf_available():
    import tensorflow as tf

if is_flax_available():
    import jax

    jax_device = jax.default_backend()
else:
    jax_device = None
def require_torchdynamo(test_case):
    """Skip `test_case` unless `is_torchdynamo_available()` reports that TorchDynamo is available."""
    skip_decorator = unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")
    return skip_decorator(test_case)
def require_torchao(test_case):
    """Skip `test_case` unless `is_torchao_available()` reports that torchao is installed."""
    skip_decorator = unittest.skipUnless(is_torchao_available(), "test requires torchao")
    return skip_decorator(test_case)
def require_torchao_version_greater_or_equal(torchao_version):
    """
    Build a decorator that skips a test unless torchao is available and its installed base version is
    `>= torchao_version`.

    Args:
        torchao_version: version string the installed torchao base version is compared against.

    Returns:
        A decorator to apply to the test.
    """

    def decorator(test_case):
        correct_torchao_version = is_torchao_available()
        if correct_torchao_version:
            # Compare on the base version of the installed distribution (what `.base_version` returns).
            installed_base = version.parse(importlib.metadata.version("torchao")).base_version
            correct_torchao_version = version.parse(installed_base) >= version.parse(torchao_version)

        skip_decorator = unittest.skipUnless(
            correct_torchao_version, f"Test requires torchao with the version greater than {torchao_version}."
        )
        return skip_decorator(test_case)

    return decorator
def require_torch_tensorrt_fx(test_case):
    """Skip `test_case` unless `is_torch_tensorrt_fx_available()` reports that Torch-TensorRT FX is available."""
    skip_decorator = unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")
    return skip_decorator(test_case)
def require_torch_gpu(test_case):
    """Skip `test_case` unless the module-level `torch_device` is `"cuda"`."""
    running_on_cuda = torch_device == "cuda"
    return unittest.skipUnless(running_on_cuda, "test requires CUDA")(test_case)
def require_torch_mps(test_case):
    """Skip `test_case` unless the module-level `torch_device` is `"mps"`."""
    running_on_mps = torch_device == "mps"
    return unittest.skipUnless(running_on_mps, "test requires MPS")(test_case)
def require_large_cpu_ram(test_case, memory: float = 80):
    """
    Skip `test_case` unless the machine has more than `memory` GiB of total CPU RAM.

    When psutil is not available the RAM cannot be measured and the test is returned undecorated.

    Args:
        test_case: the test to decorate.
        memory (`float`, defaults to 80): threshold in GiB that the total RAM must exceed.
    """
    if not is_psutil_available():
        return test_case

    import psutil

    total_ram_gib = psutil.virtual_memory().total / 1024**3
    skip_decorator = unittest.skipUnless(
        total_ram_gib > memory,
        f"test requires a machine with more than {memory} GiB of CPU RAM memory",
    )
    return skip_decorator(test_case)
def require_torch_large_gpu(test_case, memory: float = 20):
    """
    Skip `test_case` unless `torch_device` is `"cuda"` and CUDA device 0 has more than `memory` GiB of memory.

    Args:
        test_case: the test to decorate.
        memory (`float`, defaults to 20): threshold in GiB that the device's total memory must exceed.
    """
    if torch_device != "cuda":
        return unittest.skip(reason=f"test requires a CUDA GPU with more than {memory} GiB of memory")(test_case)

    device_memory_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    skip_decorator = unittest.skipUnless(
        device_memory_gib > memory,
        f"test requires a GPU with more than {memory} GiB of memory",
    )
    return skip_decorator(test_case)
def require_torch_large_accelerator(test_case, memory: float = 20):
    """
    Skip `test_case` unless `torch_device` is `"cuda"` or `"xpu"` and device 0 of that backend has more than
    `memory` GiB of memory.

    Args:
        test_case: the test to decorate.
        memory (`float`, defaults to 20): threshold in GiB that the device's total memory must exceed.
    """
    if torch_device not in ("cuda", "xpu"):
        return unittest.skip(reason=f"test requires a GPU or XPU with more than {memory} GiB of memory")(test_case)

    # `torch.cuda` or `torch.xpu`, looked up by the device name.
    backend_module = getattr(torch, torch_device)
    device_memory_gib = backend_module.get_device_properties(0).total_memory / 1024**3
    skip_decorator = unittest.skipUnless(
        device_memory_gib > memory,
        f"test requires a GPU or XPU with more than {memory} GiB of memory",
    )
    return skip_decorator(test_case)
def require_torch_gpu_if_bnb_not_multi_backend_enabled(test_case):
    """
    Return `test_case` unchanged when bitsandbytes is available with its multi-backend feature; otherwise apply
    `require_torch_gpu` to it.
    """
    multi_backend_enabled = is_bitsandbytes_available() and is_bitsandbytes_multi_backend_available()
    if not multi_backend_enabled:
        return require_torch_gpu(test_case)
    return test_case
def require_torch_accelerator(test_case):
    """Skip `test_case` unless the module-level `torch_device` is set and is not `"cpu"`."""
    has_accelerator = torch_device is not None and torch_device != "cpu"
    return unittest.skipUnless(has_accelerator, "test requires accelerator")(test_case)
def require_torch_fp16(test_case):
    """Skip `test_case` unless `is_torch_fp16_available_on_device(torch_device)` is true."""
    fp16_supported = is_torch_fp16_available_on_device(torch_device)
    return unittest.skipUnless(fp16_supported, "test requires device with fp16 support")(test_case)
def require_fp8(test_case):
    """
    Skip `test_case` unless accelerate is available and `is_fp8_available()` is true.

    `is_fp8_available()` is only called when accelerate is available.
    """
    fp8_supported = is_accelerate_available() and is_fp8_available()
    return unittest.skipUnless(fp8_supported, "test requires fp8 support")(test_case)
def require_torch_bf16(test_case):
    """Skip `test_case` unless `is_torch_bf16_available_on_device(torch_device)` is true."""
    bf16_supported = is_torch_bf16_available_on_device(torch_device)
    return unittest.skipUnless(bf16_supported, "test requires device with bf16 support")(test_case)
def require_torch_bf16_gpu(test_case):
    """
    Skip `test_case` unless `is_torch_bf16_gpu_available()` is true.

    Per the skip reason: torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0.
    """
    skip_decorator = unittest.skipUnless(
        is_torch_bf16_gpu_available(),
        "test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0",
    )
    return skip_decorator(test_case)
def require_deterministic_for_xpu(test_case):
    """
    Run `test_case` with `torch.use_deterministic_algorithms(True)` when an XPU backend is available.

    The previous setting is restored afterwards, including when the test raises. Without an XPU backend the test
    runs unchanged. The availability check happens on every call of the wrapped test.
    """

    @wraps(test_case)
    def wrapper(*args, **kwargs):
        if not is_torch_xpu_available():
            return test_case(*args, **kwargs)

        previous_setting = torch.are_deterministic_algorithms_enabled()
        try:
            torch.use_deterministic_algorithms(True)
            return test_case(*args, **kwargs)
        finally:
            torch.use_deterministic_algorithms(previous_setting)

    return wrapper
def require_torch_tf32(test_case):
    """
    Skip `test_case` unless `is_torch_tf32_available()` is true.

    Per the skip reason: Ampere or a newer GPU arch, cuda>=11 and torch>=1.7.
    """
    tf32_supported = is_torch_tf32_available()
    skip_decorator = unittest.skipUnless(
        tf32_supported, "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
    )
    return skip_decorator(test_case)
def require_detectron2(test_case):
    """Skip `test_case` unless `is_detectron2_available()` reports that detectron2 is installed."""
    skip_decorator = unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")
    return skip_decorator(test_case)
def require_faiss(test_case):
    """Skip `test_case` unless `is_faiss_available()` reports that faiss is installed."""
    skip_decorator = unittest.skipUnless(is_faiss_available(), "test requires `faiss`")
    return skip_decorator(test_case)
def require_optuna(test_case):
    """
    Skip `test_case` unless `is_optuna_available()` reports that optuna is installed.
    """
    skip_decorator = unittest.skipUnless(is_optuna_available(), "test requires optuna")
    return skip_decorator(test_case)
def require_ray(test_case):
    """
    Skip `test_case` unless `is_ray_available()` reports that Ray/tune is installed.
    """
    skip_decorator = unittest.skipUnless(is_ray_available(), "test requires Ray/tune")
    return skip_decorator(test_case)
def require_sigopt(test_case):
    """
    Skip `test_case` unless `is_sigopt_available()` reports that SigOpt is installed.
    """
    skip_decorator = unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")
    return skip_decorator(test_case)
def require_swanlab(test_case):
    """
    Skip `test_case` unless `is_swanlab_available()` reports that swanlab is installed.
    """
    skip_decorator = unittest.skipUnless(is_swanlab_available(), "test requires swanlab")
    return skip_decorator(test_case)
def require_trackio(test_case):
    """
    Skip `test_case` unless `is_trackio_available()` reports that trackio is installed.
    """
    skip_decorator = unittest.skipUnless(is_trackio_available(), "test requires trackio")
    return skip_decorator(test_case)
def require_wandb(test_case):
    """
    Skip `test_case` unless `is_wandb_available()` reports that wandb is installed.
    """
    skip_decorator = unittest.skipUnless(is_wandb_available(), "test requires wandb")
    return skip_decorator(test_case)
def require_clearml(test_case):
    """
    Skip `test_case` unless `is_clearml_available()` reports that clearml is installed.
    """
    skip_decorator = unittest.skipUnless(is_clearml_available(), "test requires clearml")
    return skip_decorator(test_case)
def require_deepspeed(test_case):
    """
    Skip `test_case` unless `is_deepspeed_available()` reports that deepspeed is installed.
    """
    skip_decorator = unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")
    return skip_decorator(test_case)
def require_apex(test_case):
    """
    Skip `test_case` unless `is_apex_available()` reports that apex is installed.
    """
    skip_decorator = unittest.skipUnless(is_apex_available(), "test requires apex")
    return skip_decorator(test_case)
def require_aqlm(test_case):
    """
    Skip `test_case` unless `is_aqlm_available()` reports that aqlm is installed.
    """
    skip_decorator = unittest.skipUnless(is_aqlm_available(), "test requires aqlm")
    return skip_decorator(test_case)
def require_vptq(test_case):
    """
    Skip `test_case` unless `is_vptq_available()` reports that vptq is installed.
    """
    skip_decorator = unittest.skipUnless(is_vptq_available(), "test requires vptq")
    return skip_decorator(test_case)
def require_spqr(test_case):
    """
    Skip `test_case` unless `is_spqr_available()` reports that spqr is installed.
    """
    skip_decorator = unittest.skipUnless(is_spqr_available(), "test requires spqr")
    return skip_decorator(test_case)
def require_eetq(test_case):
    """
    Skip `test_case` unless eetq is usable.

    eetq counts as usable when `is_eetq_available()` is true, unless importing it fails with an `ImportError`
    that mentions `shard_checkpoint`. Any other `ImportError` leaves the availability flag untouched.
    """
    usable = is_eetq_available()
    if usable:
        try:
            import eetq  # noqa: F401
        except ImportError as import_error:
            # EETQ 1.0.0 is currently broken with the latest transformers because it tries to import the removed
            # shard_checkpoint function, see https://github.com/NetEase-FuXi/EETQ/issues/34.
            # TODO: Remove once eetq releases a fix and this release is used in CI
            if "shard_checkpoint" in str(import_error):
                usable = False

    skip_decorator = unittest.skipUnless(usable, "test requires eetq")
    return skip_decorator(test_case)
def require_av(test_case):
    """
    Skip `test_case` unless `is_av_available()` reports that av is installed.
    """
    skip_decorator = unittest.skipUnless(is_av_available(), "test requires av")
    return skip_decorator(test_case)
def require_decord(test_case):
    """
    Skip `test_case` unless `is_decord_available()` reports that decord is installed.
    """
    skip_decorator = unittest.skipUnless(is_decord_available(), "test requires decord")
    return skip_decorator(test_case)
def require_bitsandbytes(test_case):
    """
    Skip `test_case` unless both bitsandbytes and torch are available.

    When both are available the test is tagged with the `bitsandbytes` pytest mark; if that raises `ImportError`
    (for example because pytest cannot be imported) the test is returned undecorated.
    """
    if not (is_bitsandbytes_available() and is_torch_available()):
        return unittest.skip(reason="test requires bitsandbytes and torch")(test_case)

    try:
        import pytest

        return pytest.mark.bitsandbytes(test_case)
    except ImportError:
        return test_case
def require_optimum(test_case):
    """
    Skip `test_case` unless `is_optimum_available()` reports that optimum is installed.
    """
    skip_decorator = unittest.skipUnless(is_optimum_available(), "test requires optimum")
    return skip_decorator(test_case)
def require_tensorboard(test_case):
    """
    Decorator for the `tensorboard` dependency.

    Skips `test_case` unless `is_tensorboard_available()` is true.
    """
    # The skip decorator must be applied to `test_case`. Returning the bare decorator would replace the decorated
    # test with a decorator object instead of a (possibly skipped) test.
    return unittest.skipUnless(is_tensorboard_available(), "test requires tensorboard")(test_case)
def require_gptq(test_case):
    """
    Skip `test_case` unless gptqmodel or auto-gptq is available.

    auto-gptq is only probed when `is_gptqmodel_available()` is false.
    """
    gptq_backend_present = is_gptqmodel_available() or is_auto_gptq_available()
    skip_decorator = unittest.skipUnless(gptq_backend_present, "test requires gptqmodel or auto-gptq")
    return skip_decorator(test_case)
def require_hqq(test_case):
    """
    Skip `test_case` unless `is_hqq_available()` reports that hqq is installed.
    """
    skip_decorator = unittest.skipUnless(is_hqq_available(), "test requires hqq")
    return skip_decorator(test_case)
def require_auto_awq(test_case):
    """
    Skip `test_case` unless `is_auto_awq_available()` reports that auto_awq is installed.
    """
    skip_decorator = unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")
    return skip_decorator(test_case)
def require_auto_round(test_case):
    """
    Skip `test_case` unless `is_auto_round_available()` reports that auto_round is installed.
    """
    skip_decorator = unittest.skipUnless(is_auto_round_available(), "test requires autoround")
    return skip_decorator(test_case)
def require_optimum_quanto(test_case):
    """
    Skip `test_case` unless `is_optimum_quanto_available()` reports that optimum-quanto is installed.
    """
    skip_decorator = unittest.skipUnless(is_optimum_quanto_available(), "test requires optimum-quanto")
    return skip_decorator(test_case)
def require_compressed_tensors(test_case):
    """
    Skip `test_case` unless `is_compressed_tensors_available()` reports that compressed_tensors is installed.
    """
    skip_decorator = unittest.skipUnless(is_compressed_tensors_available(), "test requires compressed_tensors")
    return skip_decorator(test_case)
def require_fbgemm_gpu(test_case):
    """
    Skip `test_case` unless `is_fbgemm_gpu_available()` reports that fbgemm_gpu is installed.
    """
    skip_decorator = unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")
    return skip_decorator(test_case)
def require_quark(test_case):
    """
    Skip `test_case` unless `is_quark_available()` reports that quark is installed.
    """
    skip_decorator = unittest.skipUnless(is_quark_available(), "test requires quark")
    return skip_decorator(test_case)
def require_flute_hadamard(test_case):
    """
    Skip `test_case` unless both `is_flute_available()` and `is_hadamard_available()` are true.

    The hadamard check is only made when flute is available.
    """
    both_present = is_flute_available() and is_hadamard_available()
    skip_decorator = unittest.skipUnless(both_present, "test requires flute and fast_hadamard_transform")
    return skip_decorator(test_case)
def require_fp_quant(test_case):
    """
    Skip `test_case` unless `is_fp_quant_available()` reports that fp_quant is installed.
    """
    skip_decorator = unittest.skipUnless(is_fp_quant_available(), "test requires fp_quant")
    return skip_decorator(test_case)
def require_qutlass(test_case):
    """
    Skip `test_case` unless `is_qutlass_available()` reports that qutlass is installed.
    """
    skip_decorator = unittest.skipUnless(is_qutlass_available(), "test requires qutlass")
    return skip_decorator(test_case)
def require_phonemizer(test_case):
    """
    Skip `test_case` unless `is_phonemizer_available()` reports that phonemizer is installed.
    """
    skip_decorator = unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")
    return skip_decorator(test_case)
def require_pyctcdecode(test_case):
    """
    Skip `test_case` unless `is_pyctcdecode_available()` reports that pyctcdecode is installed.
    """
    skip_decorator = unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")
    return skip_decorator(test_case)
def require_librosa(test_case):
    """
    Skip `test_case` unless `is_librosa_available()` reports that librosa is installed.
    """
    skip_decorator = unittest.skipUnless(is_librosa_available(), "test requires librosa")
    return skip_decorator(test_case)
def require_liger_kernel(test_case):
    """
    Skip `test_case` unless `is_liger_kernel_available()` reports that liger_kernel is installed.
    """
    skip_decorator = unittest.skipUnless(is_liger_kernel_available(), "test requires liger_kernel")
    return skip_decorator(test_case)
def require_essentia(test_case):
    """
    Skip `test_case` unless `is_essentia_available()` reports that essentia is installed.
    """
    skip_decorator = unittest.skipUnless(is_essentia_available(), "test requires essentia")
    return skip_decorator(test_case)
def require_pretty_midi(test_case):
    """
    Skip `test_case` unless `is_pretty_midi_available()` reports that pretty_midi is installed.
    """
    skip_decorator = unittest.skipUnless(is_pretty_midi_available(), "test requires pretty_midi")
    return skip_decorator(test_case)
def cmd_exists(cmd):
    """Return `True` when `shutil.which(cmd)` resolves `cmd` to an executable, `False` otherwise."""
    resolved_path = shutil.which(cmd)
    return resolved_path is not None
def require_usr_bin_time(test_case):
    """
    Skip `test_case` unless `cmd_exists("/usr/bin/time")` is true.
    """
    time_binary_present = cmd_exists("/usr/bin/time")
    return unittest.skipUnless(time_binary_present, "test requires /usr/bin/time")(test_case)
def require_sudachi(test_case):
    """
    Skip `test_case` unless `is_sudachi_available()` reports that sudachi is installed.
    """
    skip_decorator = unittest.skipUnless(is_sudachi_available(), "test requires sudachi")
    return skip_decorator(test_case)
def require_sudachi_projection(test_case):
    """
    Skip `test_case` unless `is_sudachi_projection_available()` is true (a sudachi that supports projection).
    """
    projection_supported = is_sudachi_projection_available()
    skip_decorator = unittest.skipUnless(projection_supported, "test requires sudachi which supports projection")
    return skip_decorator(test_case)
def require_jumanpp(test_case):
    """
    Skip `test_case` unless `is_jumanpp_available()` reports that jumanpp is installed.
    """
    skip_decorator = unittest.skipUnless(is_jumanpp_available(), "test requires jumanpp")
    return skip_decorator(test_case)
def require_cython(test_case):
    """
    Skip `test_case` unless `is_cython_available()` reports that cython is installed.
    """
    skip_decorator = unittest.skipUnless(is_cython_available(), "test requires cython")
    return skip_decorator(test_case)
def require_tiktoken(test_case):
    """
    Skip `test_case` unless `is_tiktoken_available()` reports that TikToken is installed.
    """
    skip_decorator = unittest.skipUnless(is_tiktoken_available(), "test requires TikToken")
    return skip_decorator(test_case)
def require_speech(test_case):
    """
    Skip `test_case` unless `is_speech_available()` is true. The skip reason names torchaudio.
    """
    skip_decorator = unittest.skipUnless(is_speech_available(), "test requires torchaudio")
    return skip_decorator(test_case)
def require_openai(test_case):
    """
    Skip `test_case` unless `is_openai_available()` reports that openai is installed.
    """
    skip_decorator = unittest.skipUnless(is_openai_available(), "test requires openai")
    return skip_decorator(test_case)
def require_mistral_common(test_case):
    """
    Skip `test_case` unless `is_mistral_common_available()` reports that mistral-common is installed.
    """
    skip_decorator = unittest.skipUnless(is_mistral_common_available(), "test requires mistral-common")
    return skip_decorator(test_case)
def get_gpu_count():
    """
    Return a device count from the first available framework, checked in the order torch, tf, flax.

    With torch this is `torch.cuda.device_count()`, with TensorFlow the number of physical `"GPU"` devices, and
    with Flax `jax.device_count()`. Returns 0 when none of the frameworks is available.
    """
    if is_torch_available():
        import torch

        return torch.cuda.device_count()

    if is_tf_available():
        import tensorflow as tf

        physical_gpus = tf.config.list_physical_devices("GPU")
        return len(physical_gpus)

    if is_flax_available():
        import jax

        return jax.device_count()

    return 0
def get_tests_dir(append_path=None):
    """
    Locate the `tests` directory that contains the caller's file.

    Args:
        append_path: optional path to append to the tests dir path

    Return:
        The full path to the `tests` dir, so that the tests can be invoked from anywhere. If `append_path` is
        provided, it is joined after the `tests` dir.

    Raises:
        ValueError: if neither the caller's directory nor any of its ancestors has a path ending in "tests".
    """
    # this function caller's __file__
    caller__file__ = inspect.stack()[1][1]
    tests_dir = os.path.abspath(os.path.dirname(caller__file__))
    while not tests_dir.endswith("tests"):
        parent_dir = os.path.dirname(tests_dir)
        if parent_dir == tests_dir:
            # `dirname` of the filesystem root is the root itself: stop instead of looping forever.
            raise ValueError(f"Could not find a `tests` directory above {caller__file__}")
        tests_dir = parent_dir
    if append_path:
        return os.path.join(tests_dir, append_path)
    return tests_dir
def get_steps_per_epoch(trainer: Trainer) -> int:
    """
    Return element 1 of what `trainer.set_initial_training_values` returns, which this helper treats as the
    number of steps per epoch.

    The call uses the trainer's own args and train dataloader, with `per_device_train_batch_size` passed as the
    total train batch size.
    """
    args = trainer.args
    dataloader = trainer.get_train_dataloader()
    initial_values = trainer.set_initial_training_values(
        args=args,
        dataloader=dataloader,
        total_train_batch_size=args.per_device_train_batch_size,
    )
    return initial_values[1]
def evaluate_side_effect_factory(
    side_effect_values: list[dict[str, float]],
) -> Generator[dict[str, float], None, None]:
    """
    Generator of side effects for the _evaluate method.

    Yields every entry of `side_effect_values` once, in order, and then keeps yielding the last entry forever.
    Useful when it is not known exactly how many times _evaluate will be called.
    """
    yield from side_effect_values

    # Exhausted the explicit values: repeat the final one indefinitely.
    while True:
        final_value = side_effect_values[-1]
        yield final_value
- #
- # Helper functions for dealing with testing text outputs
- # The original code came from:
- # https://github.com/fastai/fastai/blob/master/tests/utils/text.py
- # When any function contains print() calls that get overwritten, like progress bars,
- # a special care needs to be applied, since under pytest -s captured output (capsys
- # or contextlib.redirect_stdout) contains any temporary printed strings, followed by
- # \r's. This helper function ensures that the buffer will contain the same output
- # with and without -s in pytest, by turning:
- # foo bar\r tar mar\r final message
- # into:
- # final message
- # it can handle a single string or a multiline buffer
def apply_print_resets(buf):
    """
    Drop everything that was "overwritten" by a carriage return in `buf`.

    On every line of `buf`, the text up to and including the last `\\r` is removed, so that
    `"foo bar\\r tar mar\\r final message"` becomes `" final message"`. Works on a single string or a multiline buffer.
    """
    # `count` and `flags` are passed by keyword: passing them positionally is deprecated in recent Python versions
    return re.sub(r"^.*\r", "", buf, count=0, flags=re.MULTILINE)
def assert_screenout(out, what):
    """
    Assert that `what` occurs in `out`, ignoring case, once print resets have been applied to `out`.

    Args:
        out: the captured output to search in.
        what: the text expected to be found.

    Raises:
        AssertionError: if `what` is not found.
    """
    out_pr = apply_print_resets(out).lower()
    # the message used to contain a stray `f` in front of the captured output
    assert what.lower() in out_pr, f"expecting to find {what} in output: {out_pr}"
def set_config_for_less_flaky_test(config):
    """
    Force every known normalization epsilon attribute to `1.0` on `config` and on its known sub-configs.

    The attributes are set unconditionally, whether or not they existed before. Sub-configs are only touched when the
    corresponding attribute is present on `config`.
    """
    eps_attr_names = (
        "rms_norm_eps",
        "layer_norm_eps",
        "norm_eps",
        "norm_epsilon",
        "layer_norm_epsilon",
        "batch_norm_eps",
    )
    sub_config_names = ("text_config", "vision_config", "text_encoder", "audio_encoder", "decoder")

    def _force_eps(target):
        for eps_attr in eps_attr_names:
            setattr(target, eps_attr, 1.0)

    _force_eps(config)

    # norm layers (layer/group norm, etc.) could cause flaky tests when the tensors have very small variance.
    # (We don't need the original epsilon values to check eager/sdpa matches)
    for sub_config_name in sub_config_names:
        if hasattr(config, sub_config_name):
            _force_eps(getattr(config, sub_config_name))
def set_model_for_less_flaky_test(model):
    """
    Force the epsilon of normalization-like modules of a torch model to `1.0`.

    This is another way to make sure norm layers have the desired epsilon, as some models don't set it from their
    config. Nothing happens unless torch is available and `model` is a `torch.nn.Module`.
    """
    # a module is targeted when its class name ends with one of these
    norm_class_suffixes = (
        "LayerNorm",
        "GroupNorm",
        "BatchNorm",
        "RMSNorm",
        "BatchNorm2d",
        "BatchNorm1d",
        "BitGroupNormActivation",
        "WeightStandardizedConv2d",
    )
    eps_attr_names = ("eps", "epsilon", "variance_epsilon")

    if not (is_torch_available() and isinstance(model, torch.nn.Module)):
        return

    for module in model.modules():
        if not type(module).__name__.endswith(norm_class_suffixes):
            continue
        # only overwrite epsilon attributes the module already has
        for eps_attr in eps_attr_names:
            if hasattr(module, eps_attr):
                setattr(module, eps_attr, 1.0)
class CaptureStd:
    """
    Context manager to capture:

        - stdout: replay it, clean it up and make it available via `obj.out`
        - stderr: replay it and make it available via `obj.err`

    Args:
        out (`bool`, *optional*, defaults to `True`): Whether to capture stdout or not.
        err (`bool`, *optional*, defaults to `True`): Whether to capture stderr or not.
        replay (`bool`, *optional*, defaults to `True`): Whether to replay or not.
            By default each captured stream gets replayed back on context's exit, so that one can see what the test
            was doing. If this behavior is not wanted and the captured data shouldn't be replayed, pass
            `replay=False` to disable this feature.

    Until the context has exited, `obj.out` / `obj.err` hold a placeholder message rather than captured data.

    Examples:

    ```python
    # to capture stdout only with auto-replay
    with CaptureStdout() as cs:
        print("Secret message")
    assert "message" in cs.out

    # to capture stderr only with auto-replay
    import sys

    with CaptureStderr() as cs:
        print("Warning: ", file=sys.stderr)
    assert "Warning" in cs.err

    # to capture both streams with auto-replay
    with CaptureStd() as cs:
        print("Secret message")
        print("Warning: ", file=sys.stderr)
    assert "message" in cs.out
    assert "Warning" in cs.err

    # to capture just one of the streams, and not the other, with auto-replay
    with CaptureStd(err=False) as cs:
        print("Secret message")
    assert "message" in cs.out
    # but best use the stream-specific subclasses

    # to capture without auto-replay
    with CaptureStd(replay=False) as cs:
        print("Secret message")
    assert "message" in cs.out
    ```"""

    def __init__(self, out=True, err=True, replay=True):
        self.replay = replay

        # a stream that is not captured has no buffer; its text attribute then only says so
        self.out_buf = StringIO() if out else None
        self.out = (
            "error: CaptureStd context is unfinished yet, called too early" if out else "not capturing stdout"
        )

        self.err_buf = StringIO() if err else None
        self.err = (
            "error: CaptureStd context is unfinished yet, called too early" if err else "not capturing stderr"
        )

    def __enter__(self):
        if self.out_buf is not None:
            self.out_old = sys.stdout
            sys.stdout = self.out_buf

        if self.err_buf is not None:
            self.err_old = sys.stderr
            sys.stderr = self.err_buf

        return self

    def __exit__(self, *exc):
        if self.out_buf is not None:
            sys.stdout = self.out_old
            stdout_text = self.out_buf.getvalue()
            if self.replay:
                sys.stdout.write(stdout_text)
            # only stdout gets the print resets applied
            self.out = apply_print_resets(stdout_text)

        if self.err_buf is not None:
            sys.stderr = self.err_old
            stderr_text = self.err_buf.getvalue()
            if self.replay:
                sys.stderr.write(stderr_text)
            self.err = stderr_text

    def __repr__(self):
        parts = []
        if self.out_buf is not None:
            parts.append(f"stdout: {self.out}\n")
        if self.err_buf is not None:
            parts.append(f"stderr: {self.err}\n")
        return "".join(parts)
- # in tests it's the best to capture only the stream that's wanted, otherwise
- # it's easy to miss things, so unless you need to capture both streams, use the
- # subclasses below (less typing). Or alternatively, configure `CaptureStd` to
- # disable the stream you don't need to test.
class CaptureStdout(CaptureStd):
    """Same as CaptureStd but captures only stdout (stderr is left alone)."""

    def __init__(self, replay=True):
        super().__init__(out=True, err=False, replay=replay)
class CaptureStderr(CaptureStd):
    """Same as CaptureStd but captures only stderr (stdout is left alone)."""

    def __init__(self, replay=True):
        super().__init__(out=False, err=True, replay=replay)
class CaptureLogger:
    """
    Context manager to capture `logging` streams

    Args:
        logger: `logging` logger object

    Returns:
        The captured output is available via `self.out` once the context has exited.

    Example:

    ```python
    >>> from transformers import logging
    >>> from transformers.testing_utils import CaptureLogger

    >>> msg = "Testing 1, 2, 3"
    >>> logging.set_verbosity_info()
    >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
    >>> with CaptureLogger(logger) as cl:
    ...     logger.info(msg)
    >>> assert cl.out == msg + "\n"
    ```
    """

    def __init__(self, logger):
        self.logger = logger
        # an in-memory stream handler collects whatever the logger emits while the context is active
        self.io = StringIO()
        self.sh = logging.StreamHandler(self.io)
        self.out = ""

    def __enter__(self):
        self.logger.addHandler(self.sh)
        return self

    def __exit__(self, *exc):
        # detach first, then publish what was collected
        self.logger.removeHandler(self.sh)
        self.out = self.io.getvalue()

    def __repr__(self):
        return f"captured: {self.out}\n"
@contextlib.contextmanager
def LoggingLevel(level):
    """
    Context manager that temporarily switches the transformers logging verbosity to `level`.

    The verbosity that was in effect on entry is put back when the scope ends, including when the body raises.

    Example:

    ```python
    with LoggingLevel(logging.INFO):
        AutoModel.from_pretrained("openai-community/gpt2")  # calls logger.info() several times
    ```
    """
    previous_level = transformers_logging.get_verbosity()
    try:
        transformers_logging.set_verbosity(level)
        yield
    finally:
        transformers_logging.set_verbosity(previous_level)
class TemporaryHubRepo:
    """Create a temporary Hub repository and return its `RepoUrl` object. This is similar to
    `tempfile.TemporaryDirectory` and can be used as a context manager.

    The repository is created when the object is constructed; its name is the name of a throw-away temporary
    directory, optionally prefixed with `namespace`. Upon exiting the context, the repository and everything
    contained in it are removed.

    Example:

    ```python
    with TemporaryHubRepo(token=self._token) as temp_repo:
        model.push_to_hub(temp_repo.repo_id, token=self._token)
    ```
    """

    def __init__(self, namespace: Optional[str] = None, token: Optional[str] = None) -> None:
        self.token = token
        # the temporary directory is only used as a source of a unique name
        with tempfile.TemporaryDirectory() as tmp_dir:
            unique_name = Path(tmp_dir).name
            repo_id = unique_name if namespace is None else f"{namespace}/{unique_name}"
            self.repo_url = huggingface_hub.create_repo(repo_id, token=self.token)

    def __enter__(self):
        return self.repo_url

    def __exit__(self, exc, value, tb):
        delete_repo(repo_id=self.repo_url.repo_id, token=self.token, missing_ok=True)
# adapted from https://stackoverflow.com/a/64789046/9201239
@contextlib.contextmanager
def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
    """
    Temporarily put the given path at the front of `sys.path`.

    The entry is removed again when the context exits, whether or not the body raised.

    Usage :

    ```python
    with ExtendSysPath("/path/to/dir"):
        mymodule = importlib.import_module("mymodule")
    ```
    """
    entry = os.fspath(path)
    try:
        sys.path.insert(0, entry)
        yield
    finally:
        sys.path.remove(entry)
class TestCasePlus(unittest.TestCase):
    """
    This class extends *unittest.TestCase* with additional features.

    Feature 1: A set of fully resolved important file and dir path accessors.

    In tests often we need to know where things are relative to the current test file, and it's not trivial since the
    test could be invoked from more than one directory or could reside in sub-directories with different depths. This
    class solves this problem by sorting out all the basic paths and provides easy accessors to them:

    - `pathlib` objects (all fully resolved):

       - `test_file_path` - the current test file path (=`__file__`)
       - `test_file_dir` - the directory containing the current test file
       - `tests_dir` - the directory of the `tests` test suite
       - `examples_dir` - the directory of the `examples` test suite
       - `repo_root_dir` - the directory of the repository
       - `src_dir` - the directory of `src` (i.e. where the `transformers` sub-dir resides)

    - stringified paths---same as above but these return paths as strings, rather than `pathlib` objects:

       - `test_file_path_str`
       - `test_file_dir_str`
       - `tests_dir_str`
       - `examples_dir_str`
       - `repo_root_dir_str`
       - `src_dir_str`

    Feature 2: Flexible auto-removable temporary dirs which are guaranteed to get removed at the end of test.

    1. Create a unique temporary dir:

    ```python
    def test_whatever(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
    ```

    `tmp_dir` will contain the path to the created temporary dir. It will be automatically removed at the end of the
    test.

    2. Create a temporary dir of my choice, ensure it's empty before the test starts and don't
    empty it after the test.

    ```python
    def test_whatever(self):
        tmp_dir = self.get_auto_remove_tmp_dir("./xxx")
    ```

    This is useful for debug when you want to monitor a specific directory and want to make sure the previous tests
    didn't leave any data in there.

    3. You can override the first two options by directly overriding the `before` and `after` args, leading to the
       following behavior:

    `before=True`: the temporary dir will always be cleared at the beginning of the test.

    `before=False`: if the temporary dir already existed, any existing files will remain there.

    `after=True`: the temporary dir will always be deleted at the end of the test.

    `after=False`: the temporary dir will always be left intact at the end of the test.

    Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the project repository checkout are
    allowed if an explicit `tmp_dir` is used, so that by mistake no `/tmp` or similar important part of the filesystem
    will get nuked. i.e. please always pass paths that start with `./`

    Note 2: Each test can register multiple temporary dirs and they all will get auto-removed, unless requested
    otherwise.

    Feature 3: Get a copy of the `os.environ` object that sets up `PYTHONPATH` specific to the current test suite. This
    is useful for invoking external programs from the test suite - e.g. distributed training.


    ```python
    def test_whatever(self):
        env = self.get_env()
    ```"""

    def setUp(self):
        # get_auto_remove_tmp_dir feature:
        self.teardown_tmp_dirs = []

        # figure out the resolved paths for repo_root, tests, examples, etc.
        self._test_file_path = inspect.getfile(self.__class__)
        path = Path(self._test_file_path).resolve()
        self._test_file_dir = path.parents[0]
        # the repo root is the closest of the 2nd..4th ancestors that holds both `src` and `tests`
        for up in [1, 2, 3]:
            tmp_dir = path.parents[up]
            if (tmp_dir / "src").is_dir() and (tmp_dir / "tests").is_dir():
                break
        # NOTE(review): `tmp_dir` is a `Path`, hence always truthy: when no candidate matched, the last one is
        # used as the repo root and the `raise` below is never reached. Left as is, since turning it into a hard
        # failure would break test suites living outside of a `src`/`tests` layout - confirm the intent.
        if tmp_dir:
            self._repo_root_dir = tmp_dir
        else:
            raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}")
        self._tests_dir = self._repo_root_dir / "tests"
        self._examples_dir = self._repo_root_dir / "examples"
        self._src_dir = self._repo_root_dir / "src"

    @property
    def test_file_path(self):
        return self._test_file_path

    @property
    def test_file_path_str(self):
        return str(self._test_file_path)

    @property
    def test_file_dir(self):
        return self._test_file_dir

    @property
    def test_file_dir_str(self):
        return str(self._test_file_dir)

    @property
    def tests_dir(self):
        return self._tests_dir

    @property
    def tests_dir_str(self):
        return str(self._tests_dir)

    @property
    def examples_dir(self):
        return self._examples_dir

    @property
    def examples_dir_str(self):
        return str(self._examples_dir)

    @property
    def repo_root_dir(self):
        return self._repo_root_dir

    @property
    def repo_root_dir_str(self):
        return str(self._repo_root_dir)

    @property
    def src_dir(self):
        return self._src_dir

    @property
    def src_dir_str(self):
        return str(self._src_dir)

    def get_env(self):
        """
        Return a copy of the `os.environ` object that sets up `PYTHONPATH` correctly, depending on the test suite it's
        invoked from. This is useful for invoking external programs from the test suite - e.g. distributed training.

        It always inserts the repository root and `./src` first, then `./tests` or `./examples` depending on the test
        suite type and finally the preset `PYTHONPATH` if any (all full resolved paths).

        """
        env = os.environ.copy()
        paths = [self.repo_root_dir_str, self.src_dir_str]
        if "/examples" in self.test_file_dir_str:
            paths.append(self.examples_dir_str)
        else:
            paths.append(self.tests_dir_str)
        paths.append(env.get("PYTHONPATH", ""))

        # `os.pathsep` is ":" on POSIX (same result as before) and the proper separator elsewhere
        env["PYTHONPATH"] = os.pathsep.join(paths)
        return env

    def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
        """
        Args:
            tmp_dir (`string`, *optional*):
                if `None`:

                   - a unique temporary path will be created
                   - sets `before=True` if `before` is `None`
                   - sets `after=True` if `after` is `None`
                else:

                   - `tmp_dir` will be created
                   - sets `before=True` if `before` is `None`
                   - sets `after=False` if `after` is `None`
            before (`bool`, *optional*):
                If `True` and the `tmp_dir` already exists, make sure to empty it right away; if `False` and the
                `tmp_dir` already exists, any existing files will remain there.
            after (`bool`, *optional*):
                If `True`, delete the `tmp_dir` at the end of the test; if `False`, leave the `tmp_dir` and its
                contents intact at the end of the test.

        Returns:
            tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir

        Raises:
            ValueError: if an explicit `tmp_dir` does not start with `./`.
        """
        if tmp_dir is not None:
            # defining the most likely desired behavior for when a custom path is provided.
            # this most likely indicates the debug mode where we want an easily locatable dir that:
            # 1. gets cleared out before the test (if it already exists)
            # 2. is left intact after the test
            if before is None:
                before = True
            if after is None:
                after = False

            # using provided path
            path = Path(tmp_dir).resolve()

            # to avoid nuking parts of the filesystem, only relative paths are allowed
            if not tmp_dir.startswith("./"):
                raise ValueError(
                    f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
                )

            # ensure the dir is empty to start with
            if before is True and path.exists():
                shutil.rmtree(tmp_dir, ignore_errors=True)

            path.mkdir(parents=True, exist_ok=True)

        else:
            # defining the most likely desired behavior for when a unique tmp path is auto generated
            # (not a debug mode), here we require a unique tmp dir that:
            # 1. is empty before the test (it will be empty in this situation anyway)
            # 2. gets fully removed after the test
            if before is None:
                before = True
            if after is None:
                after = True

            # using unique tmp dir (always empty, regardless of `before`)
            tmp_dir = tempfile.mkdtemp()

        if after is True:
            # register for deletion
            self.teardown_tmp_dirs.append(tmp_dir)

        return tmp_dir

    def python_one_liner_max_rss(self, one_liner_str):
        """
        Runs the passed python one liner (just the code) and returns how much max cpu memory was used to run the
        program.

        Args:
            one_liner_str (`string`):
                a python one liner code that gets passed to `python -c`

        Returns:
            max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run.

        Requirements:
            this helper needs `/usr/bin/time` to be installed (`apt install time`)

        Example:

        ```
        one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("google-t5/t5-large")'
        max_rss = self.python_one_liner_max_rss(one_liner_str)
        ```
        """

        if not cmd_exists("/usr/bin/time"):
            raise ValueError("/usr/bin/time is required, install with `apt install time`")

        # the argument list is built directly rather than by quoting the code into a shell-like string and
        # splitting it again, so that one liners containing single quotes are passed through intact
        cmd = ["/usr/bin/time", "-f", "%M", "python", "-c", one_liner_str]
        with CaptureStd() as cs:
            execute_subprocess_async(cmd, env=self.get_env())
        # returned data is in KB so convert to bytes
        max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024
        return max_rss

    def tearDown(self):
        # get_auto_remove_tmp_dir feature: remove registered temp dirs
        for path in self.teardown_tmp_dirs:
            shutil.rmtree(path, ignore_errors=True)
        self.teardown_tmp_dirs = []
        if is_accelerate_available():
            AcceleratorState._reset_state()
            PartialState._reset_state()

            # delete all the env variables having `ACCELERATE` in them
            for k in list(os.environ.keys()):
                if "ACCELERATE" in k:
                    del os.environ[k]
def mockenv(**kwargs):
    """
    Convenience wrapper around `mock.patch.dict(os.environ, ...)`, which allows this::

        @mockenv(RUN_SLOW="1", USE_TF="0")
        def test_something():
            run_slow = os.getenv("RUN_SLOW", False)
            use_tf = os.getenv("USE_TF", False)

    Note that `os.environ` only accepts string values.
    """
    env_patcher = mock.patch.dict(os.environ, kwargs)
    return env_patcher
# from https://stackoverflow.com/a/34333710/9201239
@contextlib.contextmanager
def mockenv_context(*remove, **update):
    """
    Temporarily updates the `os.environ` dictionary in-place. Similar to mockenv

    The `os.environ` dictionary is updated in-place so that the modification is sure to work in all situations.

    On exit, variables that existed before are set back to their previous value and variables that were introduced
    by `update` are removed (it is not an error if the body already removed them).

    Args:
      remove: Environment variables to remove.
      update: Dictionary of environment variables and values to add/update.
    """
    env = os.environ
    update = update or {}
    remove = remove or []

    # List of environment variables being updated or removed.
    stomped = (set(update.keys()) | set(remove)) & set(env.keys())
    # Environment variables and values to restore on exit.
    update_after = {k: env[k] for k in stomped}
    # Environment variables and values to remove on exit.
    remove_after = frozenset(k for k in update if k not in env)

    try:
        env.update(update)
        for k in remove:
            env.pop(k, None)
        yield
    finally:
        env.update(update_after)
        for k in remove_after:
            # tolerate the body having deleted a variable this context added
            env.pop(k, None)
# --- pytest conf functions --- #

# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
pytest_opt_registered = {}


def pytest_addoption_shared(parser):
    """
    Register the `--make-reports` option on `parser`, at most once per process.

    This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.

    It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
    option: a later call for an already registered option does nothing.
    """
    option = "--make-reports"
    if option in pytest_opt_registered:
        return

    parser.addoption(
        option,
        action="store",
        default=False,
        help="generate report files. The value of this option is used as a prefix to report names",
    )
    pytest_opt_registered[option] = 1
def pytest_terminal_summary_main(tr, id):
    """
    Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
    directory. The report files are prefixed with the test suite name.

    This function emulates --duration and -rA pytest arguments.

    This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
    there.

    Args:
    - tr: `terminalreporter` passed from `conftest.py`
    - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
      needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.

    NB: this functions taps into a private _pytest API and while unlikely, it could break should pytest do internal
    changes - also it calls default internal methods of terminalreporter which can be hijacked by various `pytest-`
    plugins and interfere.

    """
    from _pytest.config import create_terminal_writer

    if not len(id):
        id = "tests"

    config = tr.config
    orig_writer = config.get_terminal_writer()
    orig_tbstyle = config.option.tbstyle
    orig_reportchars = tr.reportchars

    report_dir = f"reports/{id}"
    Path(report_dir).mkdir(parents=True, exist_ok=True)
    report_files = {
        k: f"{report_dir}/{k}.txt"
        for k in [
            "durations",
            "errors",
            "failures_long",
            "failures_short",
            "failures_line",
            "passes",
            "stats",
            "summary_short",
            "warnings",
        ]
    }

    # custom durations report
    # note: there is no need to call pytest --durations=XX to get this separate report
    # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
    dlist = []
    for replist in tr.stats.values():
        for rep in replist:
            if hasattr(rep, "duration"):
                dlist.append(rep)
    if dlist:
        dlist.sort(key=lambda x: x.duration, reverse=True)
        with open(report_files["durations"], "w") as f:
            durations_min = 0.05  # sec
            f.write("slowest durations\n")
            for i, rep in enumerate(dlist):
                if rep.duration < durations_min:
                    f.write(f"{len(dlist) - i} durations < {durations_min} secs were omitted")
                    break
                f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")

    def summary_failures_short(tr):
        # expecting that the reports were --tb=long (default) so we chop them off here to the last frame
        reports = tr.getreports("failed")
        if not reports:
            return
        tr.write_sep("=", "FAILURES SHORT STACK")
        for rep in reports:
            msg = tr._getfailureheadline(rep)
            tr.write_sep("_", msg, red=True, bold=True)
            # chop off the optional leading extra frames, leaving only the last one
            # (`count`/`flags` are passed by keyword: positional use is deprecated in recent Python versions)
            longrepr = re.sub(
                r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, count=0, flags=re.MULTILINE | re.DOTALL
            )
            tr._tw.line(longrepr)
            # note: not printing out any rep.sections to keep the report short

    # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
    # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
    # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
    # pytest-instafail does that)

    # from here on the reporter and the config are temporarily modified: the `finally` clause guarantees that
    # they get restored even when one of the report steps raises
    try:
        # report failures with line/short/long styles
        config.option.tbstyle = "auto"  # full tb
        with open(report_files["failures_long"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            tr.summary_failures()

        # config.option.tbstyle = "short" # short tb
        with open(report_files["failures_short"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            summary_failures_short(tr)

        config.option.tbstyle = "line"  # one line per error
        with open(report_files["failures_line"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            tr.summary_failures()

        with open(report_files["errors"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            tr.summary_errors()

        with open(report_files["warnings"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            tr.summary_warnings()  # normal warnings
            tr.summary_warnings()  # final warnings

        tr.reportchars = "wPpsxXEf"  # emulate -rA (used in summary_passes() and short_test_summary())

        # Skip the `passes` report, as it starts to take more than 5 minutes, and sometimes it timeouts on CircleCI
        # if it takes > 10 minutes (as this part doesn't generate any output on the terminal).
        # (also, it seems there is no useful information in this report, and we rarely need to read it)
        # with open(report_files["passes"], "w") as f:
        #     tr._tw = create_terminal_writer(config, f)
        #     tr.summary_passes()

        with open(report_files["summary_short"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            tr.short_test_summary()

        with open(report_files["stats"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            tr.summary_stats()
    finally:
        # restore:
        tr._tw = orig_writer
        tr.reportchars = orig_reportchars
        config.option.tbstyle = orig_tbstyle
- # --- distributed testing functions --- #
- # adapted from https://stackoverflow.com/a/59041913/9201239
- import asyncio # noqa
class _RunOutput:
    """Plain holder for the outcome of a subprocess run: its return code plus what was kept of stdout and stderr."""

    def __init__(self, returncode, stdout, stderr):
        self.returncode, self.stdout, self.stderr = returncode, stdout, stderr
async def _read_stream(stream, callback):
    """Feed `callback` with every line awaited from `stream.readline()`, stopping at the first falsy (empty) line."""
    while line := await stream.readline():
        callback(line)
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
    """
    Run `cmd` as a subprocess and collect its stdout and stderr line by line while it runs.

    Every line is decoded as utf-8, stripped of trailing whitespace and stored; unless `quiet` is set it is also
    printed right away, prefixed with `stdout:` / `stderr:`, to the matching stream of this process. When `echo` is
    set, the command line is printed before the subprocess is started.

    Returns:
        a `_RunOutput` with the return code and the two lists of collected lines.
    """
    if echo:
        print("\nRunning: ", " ".join(cmd))

    proc = await asyncio.create_subprocess_exec(
        cmd[0],
        *cmd[1:],
        stdin=stdin,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )

    # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
    #
    # If it starts hanging, will need to switch to the following code. The problem is that no data
    # will be seen until it's done and if it hangs for example there will be no debug info.
    # out, err = await p.communicate()
    # return _RunOutput(p.returncode, out, err)

    out_lines = []
    err_lines = []

    def _record(raw_line, sink, pipe, label=""):
        text = raw_line.decode("utf-8").rstrip()
        sink.append(text)
        if not quiet:
            print(label, text, file=pipe)

    # `sys.stdout` / `sys.stderr` are looked up for each line, not once up front
    def _on_stdout(raw_line):
        _record(raw_line, out_lines, sys.stdout, label="stdout:")

    def _on_stderr(raw_line):
        _record(raw_line, err_lines, sys.stderr, label="stderr:")

    readers = [
        asyncio.create_task(_read_stream(proc.stdout, _on_stdout)),
        asyncio.create_task(_read_stream(proc.stderr, _on_stderr)),
    ]
    # XXX: the timeout doesn't seem to make any difference here
    await asyncio.wait(readers, timeout=timeout)

    returncode = await proc.wait()
    return _RunOutput(returncode, out_lines, err_lines)
def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
    """
    Run `cmd` through `_stream_subprocess` on an event loop and check its outcome.

    Args:
        cmd: the command to run, as a list of strings.
        env: environment passed on to the subprocess.
        stdin: stdin passed on to the subprocess.
        timeout: passed on to `_stream_subprocess`.
        quiet: passed on to `_stream_subprocess`.
        echo: passed on to `_stream_subprocess`.

    Returns:
        the `_RunOutput` of the run.

    Raises:
        RuntimeError: if the return code is positive, or if the subprocess produced no output at all.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # there is no current event loop (e.g. when called outside of the main thread, or on Python versions
        # where `get_event_loop` doesn't create one implicitly anymore): create and register one
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    result = loop.run_until_complete(
        _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
    )

    cmd_str = " ".join(cmd)
    # NOTE(review): a negative return code is not treated as a failure here - confirm this is intended
    if result.returncode > 0:
        stderr = "\n".join(result.stderr)
        raise RuntimeError(
            f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
            f"The combined stderr from workers follows:\n{stderr}"
        )

    # check that the subprocess actually did run and produced some output, should the test rely on
    # the remote side to do the testing
    if not result.stdout and not result.stderr:
        raise RuntimeError(f"'{cmd_str}' produced no output.")

    return result
def pytest_xdist_worker_id():
    """
    Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0
    if `-n 1` or `pytest-xdist` isn't being used.

    The id is read from the `PYTEST_XDIST_WORKER` environment variable (e.g. `gw3` gives `3`).
    """
    worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
    # `count`/`flags` are passed by keyword: passing them positionally is deprecated in recent Python versions
    worker = re.sub(r"^gw", "", worker, count=0, flags=re.MULTILINE)
    return int(worker)
def get_torch_dist_unique_port():
    """
    Returns a port number that can be fed to `torch.distributed.launch`'s `--master_port` argument.

    The port is `29500` plus the `pytest-xdist` worker id, so that concurrent tests don't try to use the same port at
    once.
    """
    base_port = 29500
    return base_port + pytest_xdist_worker_id()
def nested_simplify(obj, decimals=3):
    """
    Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test
    within tests.

    Args:
        obj: a (possibly nested) structure made of lists, tuples, mappings, numpy arrays/scalars, torch or tf
            tensors, strings, ints, floats and `None`.
        decimals: the number of decimals floats are rounded to, at every nesting level.

    Returns:
        the simplified structure: lists stay lists, tuples stay tuples, mappings become dicts, arrays and tensors
        become (nested) lists.

    Raises:
        Exception: if an object of an unsupported type is met.
    """
    import numpy as np

    if isinstance(obj, list):
        return [nested_simplify(item, decimals) for item in obj]
    if isinstance(obj, tuple):
        return tuple(nested_simplify(item, decimals) for item in obj)
    if isinstance(obj, np.ndarray):
        # `decimals` has to be forwarded, or arrays would always be rounded with the default
        return nested_simplify(obj.tolist(), decimals)
    if isinstance(obj, Mapping):
        return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
    if isinstance(obj, (str, int, np.int64)) or obj is None:
        return obj
    # plain scalars are dealt with before the framework checks: no tensor is an instance of these types
    if isinstance(obj, float):
        return round(obj, decimals)
    if isinstance(obj, (np.int32, np.float32, np.float16)):
        return nested_simplify(obj.item(), decimals)
    if is_torch_available() and isinstance(obj, torch.Tensor):
        return nested_simplify(obj.tolist(), decimals)
    if is_tf_available() and tf.is_tensor(obj):
        return nested_simplify(obj.numpy().tolist(), decimals)
    raise Exception(f"Not supported: {type(obj)}")
def check_json_file_has_correct_format(file_path):
    """
    Assert that the JSON file at `file_path` is laid out as expected.

    A single-line file must be exactly `{}`. Any other file must span at least 3 lines, start with a `{` line, end
    with a `}` line, have its first inner line indented by exactly 2 and every inner line indented by at least 2
    (nested values are indented deeper).

    Raises:
        AssertionError: if the file doesn't follow this layout.
    """
    with open(file_path) as f:
        lines = f.readlines()

    if len(lines) == 1:
        # length can only be 1 if dict is empty
        assert lines[0] == "{}"
        return

    # otherwise make sure json has correct format (at least 3 lines)
    assert len(lines) >= 3
    # each key one line, indent should be 2, min length is 3
    assert lines[0].strip() == "{"
    # the first inner line holds a top-level key: it sets the indentation unit
    assert len(lines[1]) - len(lines[1].lstrip()) == 2
    # every inner line is checked on its own (the loop used to re-check the first inner line only)
    for line in lines[1:-1]:
        left_indent = len(line) - len(line.lstrip())
        assert left_indent >= 2
    assert lines[-1].strip() == "}"
def to_2tuple(x):
    """Return `x` unchanged when it is iterable (strings included), otherwise the pair `(x, x)`."""
    return x if isinstance(x, collections.abc.Iterable) else (x, x)
- # These utils relate to ensuring the right error message is received when running scripts
class SubprocessCallException(Exception):
    """Raised by `run_command` when the command it ran failed."""
def run_command(command: list[str], return_stdout=False):
    """
    Runs `command` with `subprocess.check_output`, stderr being merged into stdout.

    Args:
        command: the command to run, as a list of strings.
        return_stdout: whether to return the output of the command (decoded as utf-8) rather than nothing.

    Raises:
        SubprocessCallException: if the command exits with an error; the message carries the command's output.
    """
    try:
        raw_output = subprocess.check_output(command, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        raise SubprocessCallException(
            f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
        ) from e

    if not return_stdout:
        return None
    return raw_output.decode("utf-8") if hasattr(raw_output, "decode") else raw_output
class RequestCounter:
    """
    Helper class that will count all requests made online.

    Might not be robust if urllib3 changes its logging format but should be good enough for us.

    The counting relies on the debug log calls of `urllib3.connectionpool`, which are patched while the context is
    active. Only calls made from the thread that entered the context are counted, and calls logged with a 307
    status code are ignored.

    Usage:
    ```py
    with RequestCounter() as counter:
        _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
    assert counter["GET"] == 0
    assert counter["HEAD"] == 1
    assert counter.total_calls == 1
    ```
    """

    def __enter__(self):
        self._counter = defaultdict(int)
        self._thread_id = threading.get_ident()
        # thread ids of the patched log calls, in call order
        self._extra_info = []

        def patched_with_thread_info(func):
            def wrap(*args, **kwargs):
                self._extra_info.append(threading.get_ident())
                return func(*args, **kwargs)

            return wrap

        self.patcher = patch.object(
            urllib3.connectionpool.log, "debug", side_effect=patched_with_thread_info(urllib3.connectionpool.log.debug)
        )
        self.mock = self.patcher.start()
        return self

    def __exit__(self, *args, **kwargs) -> None:
        # the patch has to be undone whatever happens below, or it would stay active for the rest of the session
        try:
            assert len(self.mock.call_args_list) == len(self._extra_info)

            for thread_id, call in zip(self._extra_info, self.mock.call_args_list):
                if thread_id != self._thread_id:
                    continue
                # code 307: the URL being requested by the user has moved to a temporary location
                if call.args[-2] == 307:
                    continue
                log = call.args[0] % call.args[1:]
                for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
                    if method in log:
                        self._counter[method] += 1
                        break
        finally:
            self.patcher.stop()

    def __getitem__(self, key: str) -> int:
        return self._counter[key]

    @property
    def total_calls(self) -> int:
        return sum(self._counter.values())
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
    """
    To decorate flaky tests. They will be retried on failures.

    Please note that our push tests use `pytest-rerunfailures`, which prompts the CI to rerun certain types of
    failed tests. More specifically, if the test exception contains any substring in `FLAKY_TEST_FAILURE_PATTERNS`
    (in `.circleci/create_circleci_config.py`), it will be rerun. If you find a recurrent pattern of failures,
    expand `FLAKY_TEST_FAILURE_PATTERNS` in our CI configuration instead of using `is_flaky`.

    The decorated test is skipped altogether when `_run_flaky_tests` is falsy.

    Args:
        max_attempts (`int`, *optional*, defaults to 5):
            The maximum number of attempts to retry the flaky test.
        wait_before_retry (`float`, *optional*):
            If provided, will wait that number of seconds before retrying the test.
        description (`str`, *optional*):
            A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors,
            etc.)
    """

    def decorator(test_func_ref):
        @functools.wraps(test_func_ref)
        def wrapper(*args, **kwargs):
            # all the attempts but the last one log the failure and go for another round
            for retry_count in range(1, max_attempts):
                try:
                    return test_func_ref(*args, **kwargs)

                except Exception as err:
                    logger.error(f"Test failed with {err} at try {retry_count}/{max_attempts}.")
                    if wait_before_retry is not None:
                        time.sleep(wait_before_retry)

            # last attempt: a failure is not caught anymore
            return test_func_ref(*args, **kwargs)

        return unittest.skipUnless(_run_flaky_tests, "test is flaky")(wrapper)

    return decorator
def hub_retry(max_attempts: int = 5, wait_before_retry: Optional[float] = 2):
    """
    Decorator for tests that download from the Hub, which can fail because of transient network problems
    (timeouts, connection resets, etc.). Only `requests` network exceptions trigger a retry; any other exception
    propagates immediately.

    Args:
        max_attempts (`int`, *optional*, defaults to 5):
            The maximum number of attempts to retry the flaky test.
        wait_before_retry (`float`, *optional*, defaults to 2):
            If provided, will wait that number of seconds before retrying the test.
    """

    def decorator(test_func_ref):
        @functools.wraps(test_func_ref)
        def wrapper(*args, **kwargs):
            # Attempts 1 .. max_attempts - 1: network errors are logged and the test is tried again.
            for attempt in range(1, max_attempts):
                try:
                    return test_func_ref(*args, **kwargs)
                # We catch all exceptions related to network issues from requests
                except (
                    requests.exceptions.ConnectionError,
                    requests.exceptions.Timeout,
                    requests.exceptions.ReadTimeout,
                    requests.exceptions.HTTPError,
                    requests.exceptions.RequestException,
                ) as err:
                    logger.error(
                        f"Test failed with {err} at try {attempt}/{max_attempts} as it couldn't connect to the specified Hub repository."
                    )
                    if wait_before_retry is not None:
                        time.sleep(wait_before_retry)

            # Final attempt: whatever happens now is reported to the caller.
            return test_func_ref(*args, **kwargs)

        return wrapper

    return decorator
def run_first(test_case):
    """
    Decorator that applies the `order(1)` pytest marker to a test. With the pytest-order plugin installed, tests
    carrying this marker are guaranteed to run first.

    This is especially useful in some test settings like on a Gaudi instance where a Gaudi device can only be used by a
    single process at a time. So we make sure all tests that run in a subprocess are launched first, to avoid device
    allocation conflicts.
    """
    # Imported lazily so that pytest is only required when the decorator is actually used.
    import pytest

    first_in_order = pytest.mark.order(1)
    return first_in_order(test_case)
def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
    """
    Run the testing logic of a test in a freshly spawned child process. In particular, this can avoid (GPU) memory
    issues.

    Args:
        test_case (`unittest.TestCase`):
            The test that will run `target_func`.
        target_func (`Callable`):
            The function implementing the actual testing logic. It is called in the child process with
            `(input_queue, output_queue, timeout)`.
        inputs (`dict`, *optional*, defaults to `None`):
            The inputs that will be passed to `target_func` through an (input) queue.
        timeout (`int`, *optional*, defaults to `None`):
            The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env.
            variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`.
    """
    if timeout is None:
        timeout = int(os.environ.get("PYTEST_TIMEOUT", "600"))

    mp_context = multiprocessing.get_context("spawn")

    input_queue = mp_context.Queue(1)
    output_queue = mp_context.JoinableQueue(1)

    # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
    input_queue.put(inputs, timeout=timeout)

    child = mp_context.Process(target=target_func, args=(input_queue, output_queue, timeout))
    child.start()

    # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents
    # the test to exit properly.
    try:
        results = output_queue.get(timeout=timeout)
        output_queue.task_done()
    except Exception as e:
        child.terminate()
        test_case.fail(e)
    child.join(timeout=timeout)

    # The child reports its outcome through the "error" entry of the dict it put on the output queue.
    error = results["error"]
    if error is not None:
        test_case.fail(f"{error}")
def run_test_using_subprocess(func):
    """
    To decorate a test to run in a subprocess using the `subprocess` module. This could avoid potential GPU memory
    issues (GPU OOM or a test that causes many subsequential failing with `CUDA error: device-side assert triggered`).

    The parent pytest process re-invokes pytest on the current test (read from `PYTEST_CURRENT_TEST`) with
    `_INSIDE_SUB_PROCESS=1` set in the child's environment; inside the child the decorated function is simply
    executed. If the child exits with a non-zero status, its captured stdout is reported through `pytest.fail`.
    """
    import pytest

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if os.getenv("_INSIDE_SUB_PROCESS", None) == "1":
            # We are the child process: run the real test body.
            func(*args, **kwargs)
            return

        # `PYTEST_CURRENT_TEST` looks like "<node id> (<stage>)": drop the trailing stage.
        test = " ".join(os.environ.get("PYTEST_CURRENT_TEST").split(" ")[:-1])

        # `os.environ.copy()` returns a plain dict. A deep copy of `os.environ` would still be an `os._Environ`
        # whose item assignment calls `putenv`, leaking the two variables below into the parent process.
        env = os.environ.copy()
        env["_INSIDE_SUB_PROCESS"] = "1"
        # This prevents the entries in `short test summary info` given by the subprocess being truncated. so the
        # full information can be passed to the parent pytest process.
        # See: https://docs.pytest.org/en/stable/explanation/ci.html
        env["CI"] = "true"

        # If not subclass of `unitTest.TestCase` and `pytestconfig` is used: try to grab and use the arguments
        if "pytestconfig" in kwargs:
            pytestconfig = kwargs["pytestconfig"]
            command = list(pytestconfig.invocation_params.args)
            # Split once, outside the loop: rebinding `test` inside the loop made a second matching argument
            # fail with `AttributeError` (a list has no `split`).
            test_parts = test.split("::")[1:]
            for idx, x in enumerate(command):
                if x in pytestconfig.args:
                    # Replace the selected path/node id by this very test, addressed through the test file.
                    command[idx] = "::".join([f"{func.__globals__['__file__']}"] + test_parts)
            command = [f"{sys.executable}", "-m", "pytest"] + command
            command = [x for x in command if x != "--no-summary"]
        # Otherwise, simply run the test with no option at all
        else:
            command = [f"{sys.executable}", "-m", "pytest", f"{test}"]

        try:
            subprocess.run(command, env=env, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            exception_message = e.stdout.decode()
            lines = exception_message.split("\n")
            # Add a first line with more informative information instead of just `= test session starts =`.
            # This makes the `short test summary info` section more useful.
            if "= test session starts =" in lines[0]:
                text = ""
                for line in lines[1:]:
                    if line.startswith("FAILED "):
                        text = line[len("FAILED ") :]
                        text = "".join(text.split(" - ")[1:])
                    elif line.startswith("=") and line.endswith("=") and " failed in " in line:
                        break
                    elif len(text) > 0:
                        text += f"\n{line}"
                text = "(subprocess) " + text
                lines = [text] + lines
            exception_message = "\n".join(lines)
            # `pytest.fail` raises by itself; no explicit `raise` is needed.
            pytest.fail(exception_message, pytrace=False)

    return wrapper
- """
- The following contains utils to run the documentation tests without having to overwrite any files.
- The `preprocess_string` function adds `# doctest: +IGNORE_RESULT` markers on the fly anywhere a `load_dataset` call is
- made as a print would otherwise fail the corresponding line.
- To skip cuda tests, make sure to call `SKIP_CUDA_DOCTEST=1 pytest --doctest-modules <path_to_files_to_test>`
- """
def preprocess_string(string, skip_cuda_tests):
    """Prepare a docstring or a `.md` file to be run by doctest.

    The argument `string` would be the whole file content if it is a `.md` file. For a python file, it would be one of
    its docstring. In each case, it may contain multiple python code examples. If `skip_cuda_tests` is `True` and
    cuda usage is detected (with a heuristic), this method will return an empty string so no doctest will be run for
    `string`.

    Otherwise the returned text is `string` with ` # doctest: +IGNORE_RESULT` appended to the `>>> ...load_dataset(`
    lines of every piece that does not already contain that marker.
    """
    # Splitting on a pattern with two capture groups keeps both the fence opening (up to and including the first
    # `>>> `) and the rest of each code block as separate pieces, so joining the pieces restores the text.
    pieces = re.split(r"(```(?:python|py)\s*\n\s*>>> )(.*?```)", string, flags=re.DOTALL)

    for idx, piece in enumerate(pieces):
        if "load_dataset(" in piece and "# doctest: +IGNORE_RESULT" not in piece:
            pieces[idx] = re.sub(r"(>>> .*load_dataset\(.*)", r"\1 # doctest: +IGNORE_RESULT", piece)

        has_prompt = ">>>" in piece or "..." in piece
        if has_prompt and re.search(r"cuda|to\(0\)|device=0", piece) and skip_cuda_tests:
            # One cuda-looking example is enough to drop the whole string.
            return ""

    return "".join(pieces)
class HfDocTestParser(doctest.DocTestParser):
    """
    Overwrites the DocTestParser from doctest to properly parse the codeblocks that are formatted with black. This
    means that there are no extra lines at the end of our snippets. The `# doctest: +IGNORE_RESULT` marker is also
    added anywhere a `load_dataset` call is made as a print would otherwise fail the corresponding line.

    Tests involving cuda are skipped base on a naive pattern that should be updated if it is not enough.
    """

    # This regular expression is used to find doctest examples in a
    # string.  It defines three groups: `source` is the source code
    # (including leading indentation and prompts); `indent` is the
    # indentation of the first (PS1) line of the source code; and
    # `want` is the expected output (including leading indentation).
    # fmt: off
    _EXAMPLE_RE = re.compile(r'''
        # Source consists of a PS1 line followed by zero or more PS2 lines.
        (?P<source>
            (?:^(?P<indent> [ ]*) >>>    .*)    # PS1 line
            (?:\n           [ ]*  \.\.\. .*)*)  # PS2 lines
        \n?
        # Want consists of any non-blank lines that do not start with PS1.
        (?P<want> (?:(?![ ]*$)    # Not a blank line
             (?![ ]*>>>)          # Not a line starting with PS1
             # !!!!!!!!!!! HF Specific !!!!!!!!!!!
             (?:(?!```).)*        # Match any character except '`' until a '```' is found (this is specific to HF because black removes the last line)
             # !!!!!!!!!!! HF Specific !!!!!!!!!!!
             (?:\n|$)             # Match a new line or end of string
          )*)
        ''', re.MULTILINE | re.VERBOSE
    )
    # fmt: on

    # !!!!!!!!!!! HF Specific !!!!!!!!!!!
    # Parse the flag explicitly: `bool("0")` is `True`, so converting the raw environment string with `bool`
    # enabled the skip even when the variable was unset (default "0") or set to "0".
    skip_cuda_tests: bool = os.environ.get("SKIP_CUDA_DOCTEST", "0").strip().lower() in (
        "yes",
        "true",
        "on",
        "t",
        "y",
        "1",
    )
    # !!!!!!!!!!! HF Specific !!!!!!!!!!!

    def parse(self, string, name="<string>"):
        """
        Overwrites the `parse` method to incorporate a skip for CUDA tests, and remove logs and dataset prints before
        calling `super().parse`
        """
        string = preprocess_string(string, self.skip_cuda_tests)
        return super().parse(string, name)
class HfDoctestModule(Module):
    """
    Overwrites the `DoctestModule` of the pytest package to make sure the HFDocTestParser is used when discovering
    tests.
    """

    def collect(self) -> Iterable[DoctestItem]:
        # Imports the module at `self.path`, finds its doctests with a finder that uses `HfDocTestParser`, and
        # yields one `DoctestItem` for every doctest that contains at least one example.
        class MockAwareDocTestFinder(doctest.DocTestFinder):
            """A hackish doctest finder that overrides stdlib internals to fix a stdlib bug.

            https://github.com/pytest-dev/pytest/issues/3456 https://bugs.python.org/issue25532
            """

            def _find_lineno(self, obj, source_lines):
                """Doctest code does not take into account `@property`, this
                is a hackish way to fix it. https://bugs.python.org/issue17446

                Wrapped Doctests will need to be unwrapped so the correct line number is returned. This will be
                reported upstream. #8796
                """
                if isinstance(obj, property):
                    # Look at the getter function rather than the property object itself.
                    obj = getattr(obj, "fget", obj)

                if hasattr(obj, "__wrapped__"):
                    # Get the main obj in case of it being wrapped
                    obj = inspect.unwrap(obj)

                # Type ignored because this is a private function.
                return super()._find_lineno(  # type:ignore[misc]
                    obj,
                    source_lines,
                )

            def _find(self, tests, obj, name, module, source_lines, globs, seen) -> None:
                # Mocked objects are not searched for doctests at all.
                if _is_mocked(obj):
                    return
                with _patch_unwrap_mock_aware():
                    # Type ignored because this is a private function.
                    super()._find(  # type:ignore[misc]
                        tests, obj, name, module, source_lines, globs, seen
                    )

        # `conftest.py` files are imported through the plugin manager; every other file through `import_path`.
        if self.path.name == "conftest.py":
            module = self.config.pluginmanager._importconftest(
                self.path,
                self.config.getoption("importmode"),
                rootpath=self.config.rootpath,
            )
        else:
            try:
                module = import_path(
                    self.path,
                    root=self.config.rootpath,
                    mode=self.config.getoption("importmode"),
                )
            except ImportError:
                # Import failures are turned into a skip only when `doctest_ignore_import_errors` is set.
                if self.config.getvalue("doctest_ignore_import_errors"):
                    skip("unable to import module %r" % self.path)
                else:
                    raise

        # !!!!!!!!!!! HF Specific !!!!!!!!!!!
        finder = MockAwareDocTestFinder(parser=HfDocTestParser())
        # !!!!!!!!!!! HF Specific !!!!!!!!!!!
        optionflags = get_optionflags(self)
        runner = _get_runner(
            verbose=False,
            optionflags=optionflags,
            checker=_get_checker(),
            continue_on_failure=_get_continue_on_failure(self.config),
        )
        for test in finder.find(module, module.__name__):
            if test.examples:  # skip empty doctests and cuda
                yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test)
def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable], *args, **kwargs):
    """
    Look up `device` in `dispatch_table` (falling back to the `"default"` entry when the device has no entry of its
    own) and return the result.

    A callable entry is invoked with `*args, **kwargs` and its result is returned; a non-callable entry (for
    example `None` or `0`) is returned as is.
    """
    key = device if device in dispatch_table else "default"
    entry = dispatch_table[key]

    # Some device agnostic entries are plain values or None: hand them back directly.
    if callable(entry):
        return entry(*args, **kwargs)
    return entry
# Device-agnostic dispatch tables.
#
# Each `BACKEND_*` dictionary maps a device name to the entry that `_device_agnostic_dispatch` uses for that device.
# An entry is either a callable (it gets called) or a plain value such as `None` / `0` (it gets returned as is).
# Every table has a "default" entry, used for device names without an entry of their own.
if is_torch_available():
    # Mappings from device names to callable functions to support device agnostic
    # testing.
    BACKEND_MANUAL_SEED = {
        "cuda": torch.cuda.manual_seed,
        "cpu": torch.manual_seed,
        "default": torch.manual_seed,
    }
    BACKEND_EMPTY_CACHE = {
        "cuda": torch.cuda.empty_cache,
        "cpu": None,
        "default": None,
    }
    BACKEND_DEVICE_COUNT = {
        "cuda": torch.cuda.device_count,
        "cpu": lambda: 0,
        "default": lambda: 1,
    }
    BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
        "cuda": torch.cuda.reset_max_memory_allocated,
        "cpu": None,
        "default": None,
    }
    BACKEND_MAX_MEMORY_ALLOCATED = {
        "cuda": torch.cuda.max_memory_allocated,
        "cpu": 0,
        "default": 0,
    }
    BACKEND_RESET_PEAK_MEMORY_STATS = {
        "cuda": torch.cuda.reset_peak_memory_stats,
        "cpu": None,
        "default": None,
    }
    BACKEND_MEMORY_ALLOCATED = {
        "cuda": torch.cuda.memory_allocated,
        "cpu": 0,
        "default": 0,
    }
    BACKEND_SYNCHRONIZE = {
        "cuda": torch.cuda.synchronize,
        "cpu": None,
        "default": None,
    }
    BACKEND_TORCH_ACCELERATOR_MODULE = {
        "cuda": torch.cuda,
        "cpu": None,
        "default": None,
    }
else:
    # Without torch only the "default" entries exist, so every dispatch resolves to a no-op value.
    BACKEND_MANUAL_SEED = {"default": None}
    BACKEND_EMPTY_CACHE = {"default": None}
    BACKEND_DEVICE_COUNT = {"default": lambda: 0}
    BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"default": None}
    BACKEND_RESET_PEAK_MEMORY_STATS = {"default": None}
    BACKEND_MAX_MEMORY_ALLOCATED = {"default": 0}
    BACKEND_MEMORY_ALLOCATED = {"default": 0}
    BACKEND_SYNCHRONIZE = {"default": None}
    BACKEND_TORCH_ACCELERATOR_MODULE = {"default": None}

# Register the entries of each additional accelerator backend that is available. Tables that get no entry for a
# backend keep resolving that backend to their "default" entry.
if is_torch_hpu_available():
    BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed
    BACKEND_DEVICE_COUNT["hpu"] = torch.hpu.device_count
    BACKEND_TORCH_ACCELERATOR_MODULE["hpu"] = torch.hpu

if is_torch_mlu_available():
    BACKEND_EMPTY_CACHE["mlu"] = torch.mlu.empty_cache
    BACKEND_MANUAL_SEED["mlu"] = torch.mlu.manual_seed
    BACKEND_DEVICE_COUNT["mlu"] = torch.mlu.device_count
    BACKEND_TORCH_ACCELERATOR_MODULE["mlu"] = torch.mlu

if is_torch_npu_available():
    BACKEND_EMPTY_CACHE["npu"] = torch.npu.empty_cache
    BACKEND_MANUAL_SEED["npu"] = torch.npu.manual_seed
    BACKEND_DEVICE_COUNT["npu"] = torch.npu.device_count
    BACKEND_TORCH_ACCELERATOR_MODULE["npu"] = torch.npu

if is_torch_xpu_available():
    BACKEND_EMPTY_CACHE["xpu"] = torch.xpu.empty_cache
    BACKEND_MANUAL_SEED["xpu"] = torch.xpu.manual_seed
    BACKEND_DEVICE_COUNT["xpu"] = torch.xpu.device_count
    # Note: for xpu, "reset max memory allocated" is mapped to `reset_peak_memory_stats` as well.
    BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats
    BACKEND_RESET_PEAK_MEMORY_STATS["xpu"] = torch.xpu.reset_peak_memory_stats
    BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated
    BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated
    BACKEND_SYNCHRONIZE["xpu"] = torch.xpu.synchronize
    BACKEND_TORCH_ACCELERATOR_MODULE["xpu"] = torch.xpu

if is_torch_xla_available():
    # The xla entries reuse the `torch.cuda` functions.
    # NOTE(review): presumably intended for XLA running on top of CUDA devices; confirm.
    BACKEND_EMPTY_CACHE["xla"] = torch.cuda.empty_cache
    BACKEND_MANUAL_SEED["xla"] = torch.cuda.manual_seed
    BACKEND_DEVICE_COUNT["xla"] = torch.cuda.device_count
def backend_manual_seed(device: str, seed: int):
    """Dispatch on `device` through `BACKEND_MANUAL_SEED`, passing `seed` to the selected entry if it is callable."""
    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)


def backend_empty_cache(device: str):
    """Dispatch on `device` through `BACKEND_EMPTY_CACHE` (falls back to the table's "default" entry)."""
    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)


def backend_device_count(device: str):
    """Dispatch on `device` through `BACKEND_DEVICE_COUNT` (falls back to the table's "default" entry)."""
    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)


def backend_reset_max_memory_allocated(device: str):
    """Dispatch on `device` through `BACKEND_RESET_MAX_MEMORY_ALLOCATED` (falls back to the "default" entry)."""
    return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)


def backend_reset_peak_memory_stats(device: str):
    """Dispatch on `device` through `BACKEND_RESET_PEAK_MEMORY_STATS` (falls back to the "default" entry)."""
    return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)


def backend_max_memory_allocated(device: str):
    """Dispatch on `device` through `BACKEND_MAX_MEMORY_ALLOCATED` (falls back to the "default" entry)."""
    return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)


def backend_memory_allocated(device: str):
    """Dispatch on `device` through `BACKEND_MEMORY_ALLOCATED` (falls back to the "default" entry)."""
    return _device_agnostic_dispatch(device, BACKEND_MEMORY_ALLOCATED)


def backend_synchronize(device: str):
    """Dispatch on `device` through `BACKEND_SYNCHRONIZE` (falls back to the "default" entry)."""
    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)


def backend_torch_accelerator_module(device: str):
    """Return the `BACKEND_TORCH_ACCELERATOR_MODULE` entry for `device`; a callable entry is called with no arguments."""
    return _device_agnostic_dispatch(device, BACKEND_TORCH_ACCELERATOR_MODULE)
# Optional user-provided device spec: when the `TRANSFORMERS_TEST_DEVICE_SPEC` environment variable points to a Python
# file, that file is imported, its `DEVICE_NAME` becomes `torch_device`, and the functions it defines are registered
# in the `BACKEND_*` dispatch tables under that device name.
if is_torch_available():
    # If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries
    # into device to function mappings.
    if "TRANSFORMERS_TEST_DEVICE_SPEC" in os.environ:
        device_spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"]
        if not Path(device_spec_path).is_file():
            raise ValueError(
                f"Specified path to device spec file is not a file or not found. Received '{device_spec_path}"
            )

        # Try to strip extension for later import – also verifies we are importing a
        # python file.
        # The directory containing the spec file is appended to `sys.path` so that it can be imported by name.
        device_spec_dir, _ = os.path.split(os.path.realpath(device_spec_path))
        sys.path.append(device_spec_dir)
        try:
            # `str.index` raises `ValueError` when ".py" does not occur in the path.
            # NOTE(review): the module name is cut from the path exactly as given in the environment variable
            # (not from its basename), and at the first ".py" occurrence; confirm that paths containing
            # directory components are expected to import correctly.
            import_name = device_spec_path[: device_spec_path.index(".py")]
        except ValueError as e:
            raise ValueError(f"Provided device spec file was not a Python file! Received '{device_spec_path}") from e

        device_spec_module = importlib.import_module(import_name)

        # Imported file must contain `DEVICE_NAME`. If it doesn't, terminate early.
        try:
            device_name = device_spec_module.DEVICE_NAME
        except AttributeError as e:
            raise AttributeError("Device spec file did not contain `DEVICE_NAME`") from e

        # An explicitly requested `TRANSFORMERS_TEST_DEVICE` must agree with the device named by the spec.
        if "TRANSFORMERS_TEST_DEVICE" in os.environ and torch_device != device_name:
            msg = f"Mismatch between environment variable `TRANSFORMERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
            msg += "Either unset `TRANSFORMERS_TEST_DEVICE` or ensure it matches device spec name."
            raise ValueError(msg)

        torch_device = device_name

        def update_mapping_from_spec(device_fn_dict: dict[str, Callable], attribute_name: str):
            # Registers `device_spec_module.<attribute_name>` in `device_fn_dict` under the key `torch_device`.
            # If the spec does not define the attribute, the table is left untouched as long as it has a
            # "default" entry; otherwise an `AttributeError` is raised.
            try:
                # Try to import the function directly
                spec_fn = getattr(device_spec_module, attribute_name)
                device_fn_dict[torch_device] = spec_fn
            except AttributeError as e:
                # If the function doesn't exist, and there is no default, throw an error
                if "default" not in device_fn_dict:
                    raise AttributeError(
                        f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
                    ) from e

        # Add one entry here for each `BACKEND_*` dictionary.
        update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
        update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
        update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
def compare_pipeline_output_to_hub_spec(output, hub_spec):
    """
    Check that the keys of a pipeline output agree with the fields of a Hub spec dataclass.

    Args:
        output:
            The pipeline output; only its keys (obtained by iterating over it) are inspected.
        hub_spec:
            A dataclass (class or instance) describing the expected output fields.

    Raises:
        `KeyError`: If a required field of `hub_spec` is absent from `output`, or if `output` contains a key that
        is not a field of `hub_spec`. The message lists the matching, missing and unexpected keys.
    """
    spec_fields = fields(hub_spec)
    all_field_names = {spec_field.name for spec_field in spec_fields}
    matching_keys = sorted(key for key in output if key in all_field_names)

    # A field is required only when it has neither a default value nor a default factory. Checking `default`
    # alone would wrongly flag `field(default_factory=...)` fields, whose `default` is also `MISSING`.
    missing_keys = [
        spec_field.name
        for spec_field in spec_fields
        if spec_field.default is MISSING and spec_field.default_factory is MISSING and spec_field.name not in output
    ]

    # All output keys must match either a required or optional field in the Hub spec
    unexpected_keys = [output_key for output_key in output if output_key not in all_field_names]

    if not (missing_keys or unexpected_keys):
        return

    error = ["Pipeline output does not match Hub spec!"]
    if matching_keys:
        error.append(f"Matching keys: {matching_keys}")
    if missing_keys:
        error.append(f"Missing required keys in pipeline output: {missing_keys}")
    if unexpected_keys:
        error.append(f"Keys in pipeline output that are not in Hub spec: {unexpected_keys}")
    raise KeyError("\n".join(error))
@require_torch
def cleanup(device: str, gc_collect: bool = False):
    """
    Release cached resources between tests.

    Args:
        device (`str`):
            Device name used to select the entry of `backend_empty_cache`.
        gc_collect (`bool`, *optional*, defaults to `False`):
            Whether to run `gc.collect()` before emptying the backend cache.
    """
    if gc_collect:
        gc.collect()
    backend_empty_cache(device)
    # Also reset the state held by `torch._dynamo`.
    torch._dynamo.reset()
# Unpacked form of the key used in the `Expectations` class: `(device_type, major, minor)`, where each element may
# be `None`.
DeviceProperties = tuple[Optional[str], Optional[int], Optional[int]]
# Helper type. Makes creating instances of `Expectations` smoother: `(device_type, version)` where `version` is
# `None`, a single `major` int, or a `(major, minor)` pair. See `unpack_device_properties` for the conversion.
PackedDeviceProperties = tuple[Optional[str], Union[None, int, tuple[int, int]]]
@cache
def get_device_properties() -> DeviceProperties:
    """
    Return the `(device_type, major, minor)` properties of the current environment (the result is cached).

    * CUDA / ROCm systems: `("cuda" | "rocm", major, minor)` taken from `torch.cuda.get_device_capability()`.
    * XPU systems: `("xpu", generation, None)`, the generation being extracted from the device architecture value.
    * Anything else: `(torch_device, None, None)`.
    """
    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        import torch

        major, minor = torch.cuda.get_device_capability()
        backend = "rocm" if IS_ROCM_SYSTEM else "cuda"
        return (backend, major, minor)

    if IS_XPU_SYSTEM:
        import torch

        # To get more info of the architecture meaning and bit allocation, refer to https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/device_architecture.def
        architecture = torch.xpu.get_device_capability()["architecture"]
        # Keep only the byte selected by the mask, shifted down to the low bits.
        generation = (architecture & 0x000000FF00000000) >> 32
        return ("xpu", generation, None)

    return (torch_device, None, None)
def unpack_device_properties(
    properties: Optional[PackedDeviceProperties] = None,
) -> DeviceProperties:
    """
    Convert a packed `(device_type, version)` tuple into the flat `(device_type, major, minor)` form.

    `version` may be `None` (both numbers unknown), a single int (major only) or a `(major, minor)` pair. When
    `properties` itself is `None`, the properties of the current environment are returned instead.
    """
    if properties is None:
        return get_device_properties()

    device_type, version = properties
    if version is None:
        return device_type, None, None
    if isinstance(version, int):
        return device_type, version, None

    major, minor = version
    return device_type, major, minor
class Expectations(UserDict[PackedDeviceProperties, Any]):
    """
    Dictionary of expected values keyed by packed device properties `(device_type, version)`, with helpers to pick
    the entry that best matches a given device.
    """

    def get_expectation(self) -> Any:
        """
        Find best matching expectation based on environment device properties. We look at device_type, major and minor
        versions of the drivers. Expectations are stored as a dictionary with keys of the form
        (device_type, (major, minor)). If the major and minor versions are not provided, we use None.
        """
        current_device = get_device_properties()
        return self.find_expectation(current_device)

    def unpacked(self) -> list[tuple[DeviceProperties, Any]]:
        """Return the stored `(key, value)` pairs with every key unpacked to `(device_type, major, minor)`."""
        pairs = []
        for packed_key, value in self.data.items():
            pairs.append((unpack_device_properties(packed_key), value))
        return pairs

    @staticmethod
    def is_default(expectation_key: PackedDeviceProperties) -> bool:
        """
        This function returns True if the expectation_key is the Default expectation (None, None).
        When an Expectation dict contains a Default value, it is generally because the test existed before Expectations.
        When we modify a test to use Expectations for a specific hardware, we don't want to affect the tests on other
        hardwares. Thus we set the previous value as the Default expectation with key (None, None) and add a value for
        the specific hardware with key (hardware_type, (major, minor)).
        """
        return not any(part is not None for part in expectation_key)

    @staticmethod
    def score(properties: DeviceProperties, other: DeviceProperties) -> float:
        """
        Returns score indicating how similar two instances of the `Properties` tuple are.
        Rules are as follows:
            * Matching `type` adds one point, semi-matching `type` adds 0.1 point (e.g. cuda and rocm).
            * If types match, matching `major` adds another point, and then matching `minor` adds another.
            * The Default expectation (None, None) is worth 0.5 point, which is better than semi-matching. More on this
            in the `is_default` function.
        """
        device_type, major, minor = properties
        other_device_type, other_major, other_minor = other

        # Default expectation: its score takes precedence over every other rule.
        if Expectations.is_default(other):
            return 0.5

        # Matching device type, maybe major and minor
        if device_type is not None and device_type == other_device_type:
            points = 1
            if major is not None and major == other_major:
                points += 1
                if minor is not None and minor == other_minor:
                    points += 1
            return points

        # Semi-matching device type, which carries less importance than the default expectation
        if device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
            return 0.1

        return 0

    def find_expectation(self, properties: DeviceProperties = (None, None, None)) -> Any:
        """
        Find best matching expectation based on provided device properties. We score each expectation, and to
        distinguish between expectations with the same score, we use the major and minor version numbers, prioritizing
        most recent versions.
        """

        def rank(entry):
            # `entry` is `((device_type, major, minor), value)`; a missing version number ranks as -1.
            key = entry[0]
            major = -1 if key[1] is None else key[1]
            minor = -1 if key[2] is None else key[2]
            return (Expectations.score(properties, key), major, minor)

        best_key, best_value = max(self.unpacked(), key=rank)

        if Expectations.score(properties, best_key) == 0:
            raise ValueError(f"No matching expectation found for {properties}")

        return best_value

    def __repr__(self):
        return str(self.data)
def patch_torch_compile_force_graph():
    """
    Patch `torch.compile` to always use `fullgraph=True`.

    This is useful when some `torch.compile` tests are running with `fullgraph=False` and we want to be able to run
    them with `fullgraph=True` in some occasion (without introducing new tests) to make sure there is no graph break.

    After PR #40137, `CompileConfig.fullgraph` is `False` by default, this patch is necessary.

    The patch is applied only when the environment variable `TORCH_COMPILE_FORCE_FULLGRAPH` is set to one of
    "yes", "true", "on", "t", "y" or "1" (case-insensitive); otherwise this function does nothing.
    """
    requested = os.environ.get("TORCH_COMPILE_FORCE_FULLGRAPH", "").lower()
    if requested not in ("yes", "true", "on", "t", "y", "1"):
        return

    import torch

    orig_method = torch.compile

    def patched(*args, **kwargs):
        # In `torch_compile`, all arguments except `model` is keyword only argument.
        kwargs["fullgraph"] = True
        return orig_method(*args, **kwargs)

    torch.compile = patched
def _get_test_info():
    """
    Collect some information about the current test.

    For example, test full name, line number, stack, traceback, etc.

    The current test is read from the `PYTEST_CURRENT_TEST` environment variable, which is expected to have the
    form `file::class::name`. The call stack is then inspected to locate the frame of the test method and the frame
    of the function named `patched` that leads to this call.

    Returns:
        A 12-tuple `(full_test_name, test_file, test_lineno, test_obj, test_method, test_frame, test_traceback,
        test_code_context, caller_path, caller_lineno, caller_code_context, test_info)`.
    """
    full_test_name = os.environ.get("PYTEST_CURRENT_TEST", "").split(" ")[0]
    # Raises `ValueError` unless the name has exactly three `::`-separated parts.
    test_file, test_class, test_name = full_test_name.split("::")

    # from the most recent frame to the top frame
    stack_from_inspect = inspect.stack()
    # but visit from the top frame to the most recent frame

    actual_test_file, _actual_test_class = test_file, test_class
    test_frame, test_obj, test_method = None, None, None
    for frame in reversed(stack_from_inspect):
        # if test_file in str(frame).replace(r"\\", "/"):
        # check frame's function + if it has `self` as locals; double check if self has the (function) name
        # TODO: Question: How about expanded?
        if (
            frame.function == test_name
            and "self" in frame.frame.f_locals
            and hasattr(frame.frame.f_locals["self"], test_name)
        ):
            # if test_name == frame.frame.f_locals["self"]._testMethodName:
            test_frame = frame
            # The test instance
            test_obj = frame.frame.f_locals["self"]
            # TODO: Do we get the (relative?) path or it's just a file name?
            # TODO: Does `test_obj` always have `tearDown` object?
            actual_test_file = frame.filename
            # TODO: check `test_method` will work used at the several places!
            test_method = getattr(test_obj, test_name)
            break

    # NOTE(review): `line_number` is only bound when a test frame was found, but it is read unconditionally
    # further down (`test_lineno = line_number`).
    if test_frame is not None:
        line_number = test_frame.lineno

    # The frame of `patched` being called (the one and the only one calling `_get_test_info`)
    # This is used to get the original method being patched in order to get the context.
    frame_of_patched_obj = None

    captured_frames = []
    to_capture = False
    # From the most outer (i.e. python's `runpy.py`) frame to most inner frame (i.e. the frame of this method)
    # Between `the test method being called` and `before entering `patched``.
    for frame in reversed(stack_from_inspect):
        if (
            frame.function == test_name
            and "self" in frame.frame.f_locals
            and hasattr(frame.frame.f_locals["self"], test_name)
        ):
            to_capture = True
        # TODO: check simply with the name is not robust.
        elif "patched" == frame.frame.f_code.co_name:
            frame_of_patched_obj = frame
            to_capture = False
            break
        if to_capture:
            captured_frames.append(frame)

    # Chain the captured frames into traceback objects, innermost frame first, so that the final `tb` starts at
    # the outermost captured frame.
    # NOTE(review): `tb` stays unbound when `captured_frames` is empty.
    tb_next = None
    for frame_info in reversed(captured_frames):
        tb = types.TracebackType(tb_next, frame_info.frame, frame_info.frame.f_lasti, frame_info.frame.f_lineno)
        tb_next = tb
    test_traceback = tb

    # The `patched` frame is expected to hold the original (unpatched) callable in a local named `orig_method`.
    origin_method_being_patched = frame_of_patched_obj.frame.f_locals["orig_method"]

    # An iterable of type `traceback.StackSummary` with each element of type `FrameSummary`
    stack = traceback.extract_stack()
    # The frame which calls `the original method being patched`
    caller_frame = None
    # From the most inner (i.e. recent) frame to the most outer frame
    # (there is no `break`: the outermost frame whose source line contains the method name is the one kept)
    for frame in reversed(stack):
        if origin_method_being_patched.__name__ in frame.line:
            caller_frame = frame

    caller_path = os.path.relpath(caller_frame.filename)
    caller_lineno = caller_frame.lineno

    test_lineno = line_number

    # Get the code context in the test function/method.
    from _pytest._code.source import Source

    with open(actual_test_file) as fp:
        s = fp.read()
        source = Source(s)
        test_code_context = "\n".join(source.getstatement(test_lineno - 1).lines)

    # Get the code context in the caller (to the patched function/method).
    with open(caller_path) as fp:
        s = fp.read()
        source = Source(s)
        caller_code_context = "\n".join(source.getstatement(caller_lineno - 1).lines)

    # Human readable report: the test name, then the test context, then the caller context.
    test_info = f"test:\n\n{full_test_name}\n\n{'-' * 80}\n\ntest context: {actual_test_file}:{test_lineno}\n\n{test_code_context}"
    test_info = f"{test_info}\n\n{'-' * 80}\n\ncaller context: {caller_path}:{caller_lineno}\n\n{caller_code_context}"

    return (
        full_test_name,
        test_file,
        test_lineno,
        test_obj,
        test_method,
        test_frame,
        test_traceback,
        test_code_context,
        caller_path,
        caller_lineno,
        caller_code_context,
        test_info,
    )
def _get_call_arguments(code_context):
    """
    Analyze the positional and keyword arguments in a call expression.

    The source text of each argument is extracted and returned in a dictionary with the keys `"positional_args"`
    (list, in call order), `"keyword_args"` (mapping from keyword name to expression), `"starargs"` (always `None`)
    and `"kwargs"` (the expression unpacked with `**`, or `None`).

    `None` is returned (after printing the error) when `code_context` cannot be parsed. An `AssertionError` is
    raised when it parses to something other than a call.
    """

    def describe(node):
        """Extract the name/expression from an AST node"""
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Constant):
            return repr(node.value)
        # Attributes and any other expression are rendered back to source code.
        return ast.unparse(node)

    # Remove the leading indentation (every run of that many spaces is removed from the text).
    leading = len(code_context) - len(code_context.lstrip())
    code_context = code_context.replace(" " * leading, "")

    try:
        # Parse the line
        call_node = ast.parse(code_context, mode="eval").body
        assert isinstance(call_node, ast.Call)

        # Extract positional arguments
        positional = [describe(arg) for arg in call_node.args]

        # Extract keyword arguments
        named = {}
        double_star = None
        for keyword in call_node.keywords:
            expression = describe(keyword.value)
            if keyword.arg is None:
                # This is **kwargs
                double_star = expression
            else:
                # Regular keyword argument
                named[keyword.arg] = expression

        return {
            "positional_args": positional,
            "keyword_args": named,
            "starargs": None,  # *args
            "kwargs": double_star,  # **kwargs
        }

    except (SyntaxError, AttributeError) as e:
        print(f"Error parsing: {e}")
        return None
def _prepare_debugging_info(test_info, info):
    """Combine the information about the test and the call information to a patched function/method within it.

    The combined text is appended to `captured_info.txt` inside the directory given by the environment variable
    `_PATCHED_TESTING_METHODS_OUTPUT_DIR` (the current working directory if it is not set), followed by a separator
    line, and is also returned.

    Args:
        test_info (`str`): Description of the running test and the caller context.
        info (`str`): Description of the call (argument names, expressions and values).

    Returns:
        `str`: `test_info` and `info` separated by a blank line.
    """
    info = f"{test_info}\n\n{info}"
    p = os.path.join(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""), "captured_info.txt")
    # TODO (ydshieh): This is not safe when we use pytest-xdist with more than 1 worker.
    # The encoding is explicit: the captured source code and values may contain non-ASCII characters, which the
    # locale's default encoding is not guaranteed to be able to encode.
    with open(p, "a", encoding="utf-8") as fp:
        fp.write(f"{info}\n\n{'=' * 120}\n\n")
    return info
- def _patched_tearDown(self, *args, **kwargs):
- """Used to report a test that has failures captured and handled by patched functions/methods (without re-raise).
- The patched functions/methods refer to the `patched` defined in `_patch_with_call_info`, which is applied to
- `torch.testing.assert_close` and `unittest.case.TestCase.assertEqual`.
- The objective is to avoid a failure being silence after being processed.
- If there is any failure that is not handled by the patched functions/methods, we add custom error message for them
- along with the usual pytest failure report.
- """
- # Check for regular failures before clearing:
- # when `_patched_tearDown` is called, the current test fails due to an assertion error given by a method being
- # patched by `_patch_with_call_info`. The patched method catches such an error and continue running the remaining
- # statements within the test. If the test fails with another error not handled by the patched methods, we don't let
- # pytest to fail and report it but the original failure (the first one that was processed) instead.
- # We still record those failures not handled by the patched methods, and add custom messages along with the usual
- # pytest failure report.
- regular_failures_info = []
- if hasattr(self, "_outcome") and self._outcome.errors:
- for error_entry in self._outcome.errors:
- test_instance, (exc_type, exc_obj, exc_tb) = error_entry
- # breakpoint()
- regular_failures_info.append(
- {
- "message": f"{str(exc_obj)}\n\n",
- "type": exc_type.__name__,
- "file": "test_modeling_vit.py",
- "line": 237, # get_deepest_frame_line(exc_tb) # Your helper function
- }
- )
- # Clear the regular failure (i.e. that is not from any of our patched assertion methods) from pytest's records.
- self._outcome.errors.clear()
- # reset back to the original tearDown method, so `_patched_tearDown` won't be run by the subsequent tests if they
- # have only test failures that are not handle by the patched methods (or no test failure at all).
- orig_tearDown = _patched_tearDown.orig_tearDown
- type(self).tearDown = orig_tearDown
- # Call the original tearDown
- orig_tearDown(self, *args, **kwargs)
- # Get the failure
- test_method = getattr(self, self._testMethodName)
- captured_failures = test_method.__func__.captured_failures[id(test_method)]
- # TODO: How could we show several exceptions in a sinigle test on the terminal? (Maybe not a good idea)
- captured_exceptions = captured_failures[0]["exception"]
- captured_traceback = captured_failures[0]["traceback"]
- # Show the cpatured information on the terminal.
- capturued_info = [x["info"] for x in captured_failures]
- capturued_info_str = f"\n\n{'=' * 80}\n\n".join(capturued_info)
- # Enhance the exception message if there were suppressed failures
- if regular_failures_info:
- enhanced_message = f"""{str(captured_exceptions)}
- {"=" * 80}
- Handled Failures: ({len(capturued_info)} handled):
- {"-" * 80}\n
- {capturued_info_str}
- {"=" * 80}
- Unhandled Failures: ({len(regular_failures_info)} unhandled):
- {"-" * 80}\n
- {", ".join(f"{info['type']}: {info['message']}{info['file']}:{info['line']}" for info in regular_failures_info)}
- {"-" * 80}
- Note: This failure occurred after other failures analyzed by the patched assertion methods.
- To see the full details, temporarily disable assertion patching.
- {"=" * 80}"""
- # Create new exception with enhanced message
- enhanced_exception = type(captured_exceptions)(enhanced_message)
- enhanced_exception.__cause__ = captured_exceptions.__cause__
- enhanced_exception.__context__ = captured_exceptions.__context__
- # Raise with your existing traceback reconstruction
- captured_exceptions = enhanced_exception
- # clean up the recorded status
- del test_method.__func__.captured_failures
- raise captured_exceptions.with_traceback(captured_traceback)
def _patch_with_call_info(module_or_class, attr_name, _parse_call_info_func, target_args):
    """
    Patch a callable `attr_name` of a module or class `module_or_class`.

    This will allow us to collect the call information, e.g. the argument names and values, also the literal
    expressions passed as the arguments.

    Args:
        module_or_class: The module or class holding the callable to patch.
        attr_name (`str`): Name of the attribute to patch. Nothing is done if the attribute is not callable.
        _parse_call_info_func (`Callable`): Called as
            `_parse_call_info_func(orig_method, args, kwargs, call_argument_expressions, target_args)` and returns
            the text describing the failed call.
        target_args (`tuple[str]`): Names of the parameters of the patched callable to report.
    """
    # Local import: only this function needs it.
    import functools

    orig_method = getattr(module_or_class, attr_name)
    if not callable(orig_method):
        return

    # `functools.wraps` keeps the name/docstring of the original and exposes it as `__wrapped__`, so introspection
    # (e.g. `inspect.signature`) of the patched callable still sees the original signature.
    @functools.wraps(orig_method)
    def patched(*args, **kwargs):
        # If the target callable is not called within a test, simply call it without modification.
        if not os.environ.get("PYTEST_CURRENT_TEST", ""):
            return orig_method(*args, **kwargs)

        try:
            # Forward the return value, so the patched callable stays a drop-in replacement.
            return orig_method(*args, **kwargs)
        except AssertionError as e:
            captured_exception = e
            (
                _,  # full_test_name
                _,  # test_file
                _,  # test_lineno
                test_obj,
                test_method,
                _,  # test_frame
                test_traceback,
                _,  # test_code_context
                _,  # caller_path
                _,  # caller_lineno
                caller_code_context,
                test_info,
            ) = _get_test_info()
            test_info = f"{test_info}\n\n{'-' * 80}\n\npatched method: {orig_method.__module__}.{orig_method.__name__}"
            call_argument_expressions = _get_call_arguments(caller_code_context)

            # This is specific
            info = _parse_call_info_func(orig_method, args, kwargs, call_argument_expressions, target_args)
            info = _prepare_debugging_info(test_info, info)

            # If the test is running in a CI environment (e.g. not a manual run), let's raise and fail the test, so
            # it behaves as usual.
            # On Github Actions or CircleCI, this is set automatically.
            # When running manually, it's the user to determine if to set it.
            # This is to avoid the patched function being called `with self.assertRaises(AssertionError):` and fails
            # because of the missing expected `AssertionError`.
            # TODO (ydshieh): If there is way to raise only when we are inside such context managers?
            # TODO (ydshieh): How not to record the failure if it happens inside `self.assertRaises(AssertionError)`?
            if os.getenv("CI") == "true":
                raise captured_exception.with_traceback(test_traceback)

            # Save this, so we can raise at the end of the current test
            captured_failure = {
                "result": "failed",
                "exception": captured_exception,
                "traceback": test_traceback,
                "info": info,
            }

            # Record the failure status and its information, so we can raise it later.
            # We are modifying the (unbound) function at class level: not its logic but only adding a new extra
            # attribute.
            test_func = test_method.__func__
            if getattr(test_func, "captured_failures", None) is None:
                test_func.captured_failures = {}
            test_func.captured_failures.setdefault(id(test_method), []).append(captured_failure)

            # This modifies the `tearDown` which will be called after every tests, but we reset it back inside
            # `_patched_tearDown`.
            test_class = type(test_obj)
            if not hasattr(test_class.tearDown, "orig_tearDown"):
                _patched_tearDown.orig_tearDown = test_class.tearDown
                test_class.tearDown = _patched_tearDown

    setattr(module_or_class, attr_name, patched)
def _parse_call_info(func, args, kwargs, call_argument_expressions, target_args):
    """
    Prepare a string containing the call info to `func`, e.g. argument names/values/expressions.

    Args:
        func (`Callable`): The callable that was called. Its signature gives the position of each parameter.
        args (`tuple`): Positional arguments of the call.
        kwargs (`dict`): Keyword arguments of the call.
        call_argument_expressions (`dict` or `None`): Source expressions of the arguments, with the keys
            `positional_args` (list) and `keyword_args` (dict). It is not modified. When it is `None`, or when the
            expression of an argument is not available, `<unknown>` is reported as the expression.
        target_args (`Iterable[str]`): Names of the parameters to report.

    Returns:
        `str`: One section per reported parameter, giving its name, source expression and formatted value.
    """
    unknown_expr = "<unknown>"

    param_position_mapping = {name: idx for idx, name in enumerate(inspect.signature(func).parameters)}

    # Work on copies, so the dictionary given by the caller is left untouched.
    expressions = call_argument_expressions or {}
    positional_exprs = list(expressions.get("positional_args") or [])
    keyword_exprs = expressions.get("keyword_args") or {}

    # called as `self.method_name()` or `xxx.method_name()`.
    if len(args) == len(positional_exprs) + 1:
        # We simply add "self" as the expression despite it might not be the actual argument name.
        # (This part is very unlikely what a user would be interest to know)
        positional_exprs.insert(0, "self")

    # A dictionary, so a name given several times in `target_args` is reported once.
    arg_info = {}
    for arg_name in target_args:
        if arg_name in kwargs:
            arg_value = kwargs[arg_name]
            # The expression is missing if the argument was given through `**kwargs`.
            arg_expr = keyword_exprs.get(arg_name, unknown_expr)
        else:
            arg_pos = param_position_mapping[arg_name]
            arg_value = args[arg_pos]
            # The expression is missing if the argument was given through `*args`.
            arg_expr = positional_exprs[arg_pos] if arg_pos < len(positional_exprs) else unknown_expr

        arg_info[arg_name] = {"arg_expr": arg_expr, "arg_value_str": _format_py_obj(arg_value)}

    sections = [
        f"{'-' * 80}\n\nargument name: `{arg_name}`\nargument expression: `{item['arg_expr']}`\n\nargument value:\n\n{item['arg_value_str']}\n\n"
        for arg_name, item in arg_info.items()
    ]

    # remove the trailing \n\n
    return "".join(sections)[:-2]
def patch_testing_methods_to_collect_info():
    """
    Patch some methods (`torch.testing.assert_close`, `unittest.case.TestCase.assertEqual`, etc).

    This will allow us to collect the call information, e.g. the argument names and values, also the literal
    expressions passed as the arguments.

    Any existing `captured_info.txt` in the directory given by the environment variable
    `_PATCHED_TESTING_METHODS_OUTPUT_DIR` (the current working directory if it is not set) is removed first.
    """
    p = os.path.join(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""), "captured_info.txt")
    Path(p).unlink(missing_ok=True)

    if is_torch_available():
        import torch

        _patch_with_call_info(torch.testing, "assert_close", _parse_call_info, target_args=("actual", "expected"))

    # `TestCase` assertion method -> names of the two compared parameters in its signature.
    testcase_targets = {
        "assertEqual": ("first", "second"),
        "assertListEqual": ("list1", "list2"),
        "assertTupleEqual": ("tuple1", "tuple2"),
        "assertSetEqual": ("set1", "set2"),
        "assertDictEqual": ("d1", "d2"),
        "assertIn": ("member", "container"),
        "assertNotIn": ("member", "container"),
        "assertLess": ("a", "b"),
        "assertLessEqual": ("a", "b"),
        "assertGreater": ("a", "b"),
        "assertGreaterEqual": ("a", "b"),
    }
    for method_name, target_args in testcase_targets.items():
        _patch_with_call_info(unittest.case.TestCase, method_name, _parse_call_info, target_args=target_args)
def torchrun(script: str, nproc_per_node: int, is_torchrun: bool = True, env: Optional[dict] = None):
    """Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary.

    Args:
        script (`str`): Python source code to run. It is written to a temporary `.py` file.
        nproc_per_node (`int`): Value of `--nproc_per_node` given to `torchrun`.
        is_torchrun (`bool`, *optional*, defaults to `True`): Whether to run the script with `torchrun`. Otherwise
            it is run with `python3`.
        env (`dict`, *optional*): Environment variables of the subprocess. The current environment is inherited
            when it is `None`.

    Raises:
        `Exception`: If the subprocess exits with a non-zero status. The message contains its standard error.
    """
    with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
        tmp.write(script)
        # Make sure the content is on disk before the subprocess reads the file.
        tmp.flush()

        # The command is built as a list (instead of splitting a string), so a temporary path containing spaces
        # stays a single argument.
        if is_torchrun:
            cmd = [
                "torchrun",
                "--nproc_per_node",
                str(nproc_per_node),
                "--master_port",
                str(get_torch_dist_unique_port()),
                tmp.name,
            ]
        else:
            cmd = ["python3", tmp.name]

        # Note that the subprocess will be waited for here, and raise an error if not successful
        try:
            subprocess.run(cmd, capture_output=True, env=env, text=True, check=True)
        except subprocess.CalledProcessError as e:
            raise Exception(f"The following error was captured: {e.stderr}") from e
def _format_tensor(t, indent_level=0, sci_mode=None):
    """Format torch's tensor in a pretty way to be shown 👀 in the test report.

    Args:
        t (`torch.Tensor` or a python number): The value to format. Non-tensor values are converted with
            `torch.tensor`.
        indent_level (`int`, *optional*, defaults to 0): Indentation level (4 spaces per level) of the output.
        sci_mode (`bool`, *optional*): Whether to print with scientific notation. If `None`, it is determined from
            the default string representation of `t`. Recursive calls receive the value determined at the top level.

    Returns:
        `str`: A one-line representation for scalars, one-dimensional tensors and tensors having a single element
        along every dimension except the last one; otherwise one line per element along the outer dimension, each
        followed by a comma.
    """

    # `torch.testing.assert_close` could accept python int/float numbers.
    if not isinstance(t, torch.Tensor):
        t = torch.tensor(t)

    # Simply make the processing below simpler (not to handle both cases)
    is_scalar = False
    if t.ndim == 0:
        t = torch.tensor([t])
        is_scalar = True

    # For scalar or one-dimensional tensor, keep it as one-line. If there is only one element along any dimension
    # except the last one, we also keep it as one-line.
    if t.ndim <= 1 or set(t.shape[0:-1]) == {1}:
        # Use `detach` to remove `grad_fn=<...>`, and use `to("cpu")` to remove `device='...'`
        t = t.detach().to("cpu")

        # We work directly with the string representation instead of the tensor itself
        t_str = str(t)

        # remove `tensor( ... )` so keep only the content
        t_str = t_str.replace("tensor(", "").replace(")", "")

        # Sometimes there are extra spaces between `[` and the first digit of the first value (for alignment).
        # For example `[[ 0.06, -0.51], [-0.76, -0.49]]`. It may have multiple consecutive spaces.
        # Let's remove such extra spaces.
        while "[ " in t_str:
            t_str = t_str.replace("[ ", "[")

        # Put everything in a single line. We replace `\n` by a space ` ` so we still keep `,\n` as `, `.
        t_str = t_str.replace("\n", " ")

        # Collapse repeated spaces (introduced by the previous step) into a single one. The pattern must be TWO
        # spaces: replacing a single space by itself would never terminate.
        while "  " in t_str:
            t_str = t_str.replace("  ", " ")

        # Remove the (first) pair of brackets added above for a scalar tensor. This is not done by slicing, as the
        # closing bracket is not the last character when a suffix like `dtype=...` is present.
        if is_scalar:
            t_str = t_str.replace("[", "", 1).replace("]", "", 1)

        return " " * 4 * indent_level + t_str

    # Otherwise, we separate the representations of every element along the outer dimension by new lines (after a
    # `,`). The representation of each element is obtained by calling this function recursively with the correct
    # `indent_level`.
    if sci_mode is None:
        # Use the original content to determine the scientific mode to use. This is required as the representation
        # of `t[index]` (computed below) may have a different format regarding scientific notation.
        t_str = str(t)
        sci_mode = "e+" in t_str or "e-" in t_str

    torch.set_printoptions(sci_mode=sci_mode)
    try:
        # Keep the ending `,` for all outer dimensions whose representations are not put in one-line, even if there
        # is only one element along that dimension.
        rows = ",\n".join(_format_tensor(x, indent_level=indent_level + 1, sci_mode=sci_mode) for x in t)
    finally:
        # Restore the automatic mode even if the formatting of an element fails.
        torch.set_printoptions(sci_mode=None)

    padding = " " * 4 * indent_level
    return f"{padding}[\n{rows},\n{padding}]"
def _quote_string(s):
    """Given a string `s`, return a python literal expression that gives `s` when it is used in a python source code.

    For example, if `s` is the string `abc`, the return value is `"abc"`.

    We choose double quotes over single quotes despite `str(s)` would give `'abc'` instead of `"abc"`. Single quotes
    are only used when `s` contains double quotes but no single quote. Backslashes are escaped, so they are not
    interpreted as the start of an escape sequence. Newlines are NOT escaped here.
    """
    # Escape the backslashes first, so the backslashes added below for the double quotes are not escaped again.
    s = s.replace("\\", "\\\\")

    if '"' in s and "'" not in s:
        return f"'{s}'"

    # Either there is no double quote (nothing is replaced), or both kinds of quotes are present and the double
    # quotes have to be escaped.
    s = s.replace('"', r"\"")
    return f'"{s}"'
def _format_py_obj(obj, indent=0, mode="", cache=None, prefix=""):
    """Format python objects of basic built-in type in a pretty way so we could copy-paste them to code editor easily.

    Currently, this supports int, float, str, list, tuple, and dict. Objects of any other type are represented by
    their `repr`.

    It also works with `torch.Tensor` via calling `_format_tensor`.

    Args:
        obj: The object to format.
        indent (`int`, *optional*, defaults to 0): Indentation level (4 spaces per level) of a container `obj`.
        mode (`str`, *optional*, defaults to `""`): If `"one-line"`, a container is put in a single line whenever
            all its elements are (this is used for the keys of a dictionary).
        cache (`dict`, *optional*): Results of previous calls, shared by the recursive calls.
        prefix (`str`, *optional*, defaults to `""`): Text sharing the line with `obj` (the key of a dictionary
            value and/or the comma following an element). Only its length is used, to check the width limit.

    Returns:
        `str`: The formatted representation of `obj`.
    """
    if cache is None:
        cache = {}
    cache_key = (id(obj), indent, mode, prefix)
    if cache_key in cache:
        return cache[cache_key]

    type_name = obj.__class__.__name__
    container_parentheses = {"list": "[]", "tuple": "()", "dict": "{}"}

    # special format method for `torch.Tensor`
    if str(obj.__class__) == "<class 'torch.Tensor'>":
        return _format_tensor(obj)

    if type_name == "str":
        # we don't want the newline being interpreted
        output = _quote_string(obj).replace("\n", r"\n")
    elif type_name in ["int", "float"]:
        # for float like `1/3`, we will get `0.3333333333333333`
        output = str(obj)
    elif type_name in container_parentheses and len(obj) == 0:
        # An empty container has no element whose type could drive the layout decisions below.
        output = " " * 4 * indent + container_parentheses[type_name]
    elif type_name in container_parentheses:
        max_width = 120
        p1, p2 = container_parentheses[type_name]
        outer_pad = " " * 4 * indent
        inner_pad = " " * (4 * (indent + 1))
        # A tuple with a single element needs a trailing comma to remain a tuple literal.
        tail = "," if type_name == "tuple" and len(obj) == 1 else ""

        is_dict = isinstance(obj, dict)
        # The values whose types drive the layout decisions (the keys of a dictionary are always in one line).
        values = list(obj.values()) if is_dict else list(obj)
        last_idx = len(values) - 1

        elements_without_indent = []
        if is_dict:
            for idx, (k, v) in enumerate(obj.items()):
                ok = _format_py_obj(k, indent=indent + 1, mode="one-line", cache=cache).lstrip()
                # The value shares its line with the key and, except for the last element, with a comma.
                ov = _format_py_obj(
                    v,
                    indent=indent + 1,
                    mode=mode,
                    cache=cache,
                    prefix=f"{ok}: " + ("," if idx != last_idx else ""),
                )
                # Each element could be multiple-line, but the indent of its first line is removed
                elements_without_indent.append(f"{ok}: {ov.lstrip()}")
        else:
            for idx, x in enumerate(obj):
                o = _format_py_obj(
                    x, indent=indent + 1, mode=mode, cache=cache, prefix="," if idx != last_idx else ""
                )
                # Each element could be multiple-line, but the indent of its first line is removed
                elements_without_indent.append(o.lstrip())

        # Each group is a list of elements that are put together in one line.
        groups = []
        buf = []
        current_type = None
        for idx, x in enumerate(elements_without_indent):
            buf.append(x)

            x_expanded = "\n" in x
            # An element that is not the last one is followed by a comma.
            width_limit = max_width - int(idx != last_idx)

            # if `x` should be separated from subsequent elements
            should_finalize_x = x_expanded or len(inner_pad) + len(x) > width_limit
            # if `buf[:-1]` (i.e. without `x`) should be combined together (into one line)
            should_finalize_buf = x_expanded or len(inner_pad) + len(", ".join(buf)) > width_limit

            # any element of iterable type need to be on its own line
            current_type = type(values[idx])
            if current_type in [list, tuple, dict]:
                should_finalize_x = True
                should_finalize_buf = True

            # any type change --> need to be added after a new line
            prev_type = None
            if len(buf) > 1:
                prev_type = type(values[idx - 1])
                if current_type != prev_type:
                    should_finalize_buf = True

            # `x` is alone in the buf, or all elements in the buf are string --> don't finalize the buf
            if prev_type is None or (prev_type is str and current_type is str):
                should_finalize_buf = False

            # collect as many elements of string type as possible (without width limit).
            # These will be examined as a whole (if not fit into the width, each element would be in its own line)
            if current_type is str:
                should_finalize_x = False

            if should_finalize_buf and len(buf) > 1:
                # Strings are the only elements collected without width limit, so only them may not fit.
                # `-1` at the end: because buf[-2] is not the last element
                not_fit_into_one_line = (
                    prev_type is str and len(inner_pad) + len(", ".join(buf[:-1])) > max_width - 1
                )
                if not_fit_into_one_line:
                    groups.extend([y] for y in buf[:-1])
                else:
                    groups.append(buf[:-1])
                buf = buf[-1:]

            if should_finalize_x:
                groups.append(buf)
                buf = []

        # The last buf
        if buf:
            # no `-1` at the end: because buf[-1] is the last element
            not_fit_into_one_line = current_type is str and len(inner_pad) + len(", ".join(buf)) > max_width
            if not_fit_into_one_line:
                groups.extend([y] for y in buf)
            else:
                groups.append(buf)

        element_strings = [inner_pad + ", ".join(group) for group in groups]
        output = f"{outer_pad}{p1}\n" + ",\n".join(element_strings) + f"{tail}\n{outer_pad}{p2}"

        # if all elements are in one-line, we can form a one-line representation of `obj`
        no_new_line_in_elements = all("\n" not in s for s in element_strings)
        if no_new_line_in_elements:
            one_line_form = p1 + ", ".join(s.lstrip() for s in element_strings) + tail + p2

            if mode == "one-line":
                use_one_line = True
            else:
                element_types = [type(v) for v in values]
                if any(t in (list, tuple, dict) for t in element_types):
                    # Only for a single element that is iterable and of the same type as `obj` (without width
                    # limit).
                    use_one_line = len(values) == 1 and type(obj) is element_types[0]
                elif len(set(element_types)) > 1:
                    # all elements are of simple types, but more than one type --> no one line repr.
                    use_one_line = False
                elif element_types[0] is str and len(values) > 1:
                    # multiple string elements --> one-line repr. if fit into width limit
                    use_one_line = len(outer_pad) + len(prefix) + len(one_line_form) <= max_width
                else:
                    # numbers, a single string, or any other simple type --> one-line repr. without width limit
                    use_one_line = True

            if use_one_line:
                output = f"{outer_pad}{one_line_form}"
    else:
        # Types without a dedicated formatting (e.g. `bool`, `None`)
        output = repr(obj)

    cache[cache_key] = output

    return output
|