api.py 135 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185318631873188318931903191319231933194319531963197319831993200320132023203320432053206320732083209321032113212321332143215321632173218321932203221322232233224322532263227322832293230323132323233323432353236323732383239324032413242324332443245324632473248324932503251325232533254
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # yapf: disable
  3. import datetime
  4. import fnmatch
  5. import functools
  6. import io
  7. import os
  8. import pickle
  9. import platform
  10. import re
  11. import shutil
  12. import tempfile
  13. import time
  14. import uuid
  15. import warnings
  16. from collections import defaultdict
  17. from http import HTTPStatus
  18. from http.cookiejar import CookieJar
  19. from os.path import expanduser
  20. from pathlib import Path
  21. from typing import (Any, BinaryIO, Dict, Iterable, List, Literal, Optional,
  22. Tuple, Union)
  23. from urllib.parse import urlencode
  24. import json
  25. import requests
  26. from requests import Session
  27. from requests.adapters import HTTPAdapter, Retry
  28. from requests.exceptions import HTTPError
  29. from tqdm.auto import tqdm
  30. from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
  31. API_HTTP_CLIENT_TIMEOUT,
  32. API_RESPONSE_FIELD_DATA,
  33. API_RESPONSE_FIELD_EMAIL,
  34. API_RESPONSE_FIELD_GIT_ACCESS_TOKEN,
  35. API_RESPONSE_FIELD_MESSAGE,
  36. API_RESPONSE_FIELD_USERNAME,
  37. DEFAULT_MAX_WORKERS,
  38. DEFAULT_MODELSCOPE_INTL_DOMAIN,
  39. MODELSCOPE_CLOUD_ENVIRONMENT,
  40. MODELSCOPE_CLOUD_USERNAME,
  41. MODELSCOPE_CREDENTIALS_PATH,
  42. MODELSCOPE_DOMAIN,
  43. MODELSCOPE_PREFER_AI_SITE,
  44. MODELSCOPE_REQUEST_ID,
  45. MODELSCOPE_URL_SCHEME, ONE_YEAR_SECONDS,
  46. REQUESTS_API_HTTP_METHOD,
  47. TEMPORARY_FOLDER_NAME,
  48. UPLOAD_BLOB_TQDM_DISABLE_THRESHOLD,
  49. UPLOAD_COMMIT_BATCH_SIZE,
  50. UPLOAD_MAX_FILE_COUNT,
  51. UPLOAD_MAX_FILE_COUNT_IN_DIR,
  52. UPLOAD_MAX_FILE_SIZE,
  53. UPLOAD_NORMAL_FILE_SIZE_TOTAL_LIMIT,
  54. UPLOAD_SIZE_THRESHOLD_TO_ENFORCE_LFS,
  55. VALID_SORT_KEYS, DatasetVisibility,
  56. Licenses, ModelVisibility, Visibility,
  57. VisibilityMap)
  58. from modelscope.hub.errors import (InvalidParameter, NotExistError,
  59. NotLoginException, RequestError,
  60. datahub_raise_on_error,
  61. handle_http_post_error,
  62. handle_http_response, is_ok,
  63. raise_for_http_status, raise_on_error)
  64. from modelscope.hub.git import GitCommandWrapper
  65. from modelscope.hub.info import DatasetInfo, ModelInfo
  66. from modelscope.hub.repository import Repository
  67. from modelscope.hub.utils.aigc import AigcModel
  68. from modelscope.hub.utils.utils import (add_content_to_file, get_domain,
  69. get_endpoint, get_readable_folder_size,
  70. get_release_datetime, is_env_true,
  71. model_id_to_group_owner_name)
  72. from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
  73. DEFAULT_MODEL_REVISION,
  74. DEFAULT_REPOSITORY_REVISION,
  75. MASTER_MODEL_BRANCH, META_FILES_FORMAT,
  76. REPO_TYPE_DATASET, REPO_TYPE_MODEL,
  77. REPO_TYPE_SUPPORT, ConfigFields,
  78. DatasetFormations, DatasetMetaFormats,
  79. DownloadChannel, DownloadMode,
  80. Frameworks, ModelFile, Tasks,
  81. VirgoDatasetConfig)
  82. from modelscope.utils.file_utils import get_file_hash, get_file_size
  83. from modelscope.utils.logger import get_logger
  84. from modelscope.utils.repo_utils import (DATASET_LFS_SUFFIX,
  85. DEFAULT_IGNORE_PATTERNS,
  86. MODEL_LFS_SUFFIX,
  87. CommitHistoryResponse, CommitInfo,
  88. CommitOperation, CommitOperationAdd,
  89. RepoUtils)
  90. from modelscope.utils.thread_utils import thread_executor
  91. logger = get_logger()
  92. class HubApi:
  93. """Model hub api interface.
  94. """
  def __init__(self,
               endpoint: Optional[str] = None,
               timeout=API_HTTP_CLIENT_TIMEOUT,
               max_retries=API_HTTP_CLIENT_MAX_RETRIES):
      """Create a hub client bound to one endpoint and one HTTP session.

      Args:
          endpoint (str, optional): The ModelScope server http|https address.
              When ``None``, the value returned by ``get_endpoint()`` is used.
          timeout: Default ``timeout`` keyword bound onto every session
              method named in ``REQUESTS_API_HTTP_METHOD``.
          max_retries: Passed as ``total`` to the retry policy of the session.
      """
      if endpoint is None:
          endpoint = get_endpoint()
      self.endpoint = endpoint
      self.headers = {'user-agent': ModelScopeConfig.get_user_agent()}
      self.session = Session()

      # One retry policy shared by both URL schemes; only the listed 5xx
      # status codes trigger a retry.
      retry_policy = Retry(
          total=max_retries,
          read=2,
          connect=2,
          backoff_factor=1,
          status_forcelist=(500, 502, 503, 504),
          respect_retry_after_header=False,
      )
      http_adapter = HTTPAdapter(max_retries=retry_policy)
      for scheme in ('http://', 'https://'):
          self.session.mount(scheme, http_adapter)

      # Rebind each HTTP verb on the session so that callers get the
      # configured timeout without having to pass it on every request.
      for verb in REQUESTS_API_HTTP_METHOD:
          unbound_call = getattr(self.session, verb)
          setattr(self.session, verb,
                  functools.partial(unbound_call, timeout=timeout))

      self.upload_checker = UploadingCheck()
  def _get_cookies(self, access_token: str):
      """
      Build a cookie jar that authenticates requests with ``access_token``.

      The token is stored as the ``m_session_id`` cookie with path ``/``.
      Its domain is the host of ``self.endpoint``, or ``get_domain()`` when
      no endpoint is configured.

      Args:
          access_token (str): user access token on ModelScope.

      Returns:
          jar (CookieJar): cookies for authentication.
      """
      from requests.cookies import RequestsCookieJar
      from urllib.parse import urlparse

      if self.endpoint:
          parsed = urlparse(self.endpoint)
          # Use the bare host rather than ``netloc``: ``netloc`` keeps any
          # ``:port`` / ``user@`` part, while cookie domain matching is done
          # against the request host with the port removed, so a domain
          # such as ``host:8080`` would never match and the cookie would
          # silently not be sent. Fall back to ``netloc`` when no host
          # could be parsed (e.g. an endpoint without a scheme).
          domain: str = parsed.hostname or parsed.netloc
      else:
          domain = get_domain()

      jar = RequestsCookieJar()
      jar.set('m_session_id',
              access_token,
              domain=domain,
              path='/')
      return jar
  def get_cookies(self, access_token, cookies_required: Optional[bool] = False):
      """
      Get cookies for authentication from access_token or the local cache.

      Args:
          access_token (str): user access token on ModelScope. When truthy,
              cookies are built from it via ``self._get_cookies``; otherwise
              ``ModelScopeConfig.get_cookies()`` is consulted.
          cookies_required (bool): whether to raise error if no cookies found, defaults to `False`.

      Returns:
          cookies (CookieJar): cookies for authentication, or ``None`` when
              nothing was found and ``cookies_required`` is falsy.

      Raises:
          ValueError: If no credentials found and cookies_required is True.
      """
      if access_token:
          cookies = self._get_cookies(access_token=access_token)
      else:
          cookies = ModelScopeConfig.get_cookies()

      if cookies is None and cookies_required:
          # Adjacent literals are concatenated verbatim, so each sentence
          # must carry its own trailing space.
          raise ValueError(
              'No credentials found. '
              'You can pass the `--token` argument, '
              'or use HubApi().login(access_token=`your_sdk_token`). '
              'Your token is available at https://modelscope.cn/my/myaccesstoken'
          )
      return cookies
  def login(
      self,
      access_token: Optional[str] = None,
      endpoint: Optional[str] = None
  ):
      """Log in with an SDK access token and persist the returned credentials.

      The git token, the response cookies and the user name/email from the
      login response are saved through ``ModelScopeConfig``.

      Args:
          access_token (str): user access token on modelscope. When ``None``,
              the ``MODELSCOPE_API_TOKEN`` environment variable is used.
              If neither yields a non-empty token, login returns immediately.
          endpoint: the endpoint to use, default to None to use endpoint specified in the class

      Returns:
          tuple: ``(git_token, cookies)`` where ``git_token`` is the token to
              access your git repository and ``cookies`` authenticate you to
              the ModelScope open-api; ``(None, None)`` when no token is
              available.
      """
      sdk_token = access_token
      if sdk_token is None:
          sdk_token = os.environ.get('MODELSCOPE_API_TOKEN')
      if not sdk_token:
          return None, None

      login_url = f'{endpoint or self.endpoint}/api/v1/login'
      resp = self.session.post(
          login_url,
          json={'AccessToken': sdk_token},
          headers=self.builder_headers(self.headers))
      raise_for_http_status(resp)
      payload = resp.json()
      raise_on_error(payload)

      git_token = payload[API_RESPONSE_FIELD_DATA][API_RESPONSE_FIELD_GIT_ACCESS_TOKEN]
      cookies = resp.cookies

      # Persist token, cookies and user identity for later calls.
      ModelScopeConfig.save_token(git_token)
      ModelScopeConfig.save_cookies(cookies)
      user_data = payload[API_RESPONSE_FIELD_DATA]
      ModelScopeConfig.save_user_info(
          user_data[API_RESPONSE_FIELD_USERNAME],
          user_data[API_RESPONSE_FIELD_EMAIL])

      return git_token, cookies
  def create_model(self,
                   model_id: str,
                   visibility: Optional[int] = ModelVisibility.PUBLIC,
                   license: Optional[str] = Licenses.APACHE_V2,
                   chinese_name: Optional[str] = None,
                   original_model_id: Optional[str] = '',
                   endpoint: Optional[str] = None,
                   token: Optional[str] = None,
                   aigc_model: Optional['AigcModel'] = None) -> str:
      """Create a model repo at ModelScope Hub.

      Args:
          model_id (str): The model id in format {owner}/{name}
          visibility (int, optional): visibility of the model, sent as
              ``Visibility``; defaults to ``ModelVisibility.PUBLIC``.
          license (str, optional): license of the model, defaults to
              ``Licenses.APACHE_V2``.
          chinese_name (str, optional): chinese name of the model.
          original_model_id (str, optional): the base model id which this model is trained from
          endpoint: the endpoint to use, default to None to use endpoint specified in the class
          token (str, optional): access token for authentication
          aigc_model (AigcModel, optional): AigcModel instance for AIGC model creation.
              If provided, the AIGC endpoint is used, weights are pre-uploaded
              before the create request, and the model files are uploaded
              to the repo afterwards.
              Refer to modelscope.hub.utils.aigc.AigcModel for details.

      Returns:
          str: URL of the created model repository

      Raises:
          InvalidParameter: If model_id is None.
          ValueError: If no credentials are available.

      Note:
          model_id = {owner}/{name}
      """
      if model_id is None:
          raise InvalidParameter('model_id is required!')

      # Credentials are mandatory for repo creation.
      cookies = self.get_cookies(access_token=token, cookies_required=True)
      endpoint = endpoint or self.endpoint

      owner_or_group, name = model_id_to_group_owner_name(model_id)
      is_aigc = aigc_model is not None

      # Fields common to regular and AIGC models.
      payload = {
          'Path': owner_or_group,
          'Name': name,
          'ChineseName': chinese_name,
          'Visibility': visibility,
          'License': license,
          'OriginalModelId': original_model_id,
          'TrainId': os.environ.get('MODELSCOPE_TRAIN_ID', '')
      }

      if is_aigc:
          create_url = f'{endpoint}/api/v1/models/aigc'
          # Best-effort pre-upload of the weights so the server recognizes
          # their sha256 (reuses the cookies obtained above).
          aigc_model.preupload_weights(
              cookies=cookies,
              headers=self.builder_headers(self.headers),
              endpoint=endpoint)
          aigc_fields = {
              'TagShowName': aigc_model.tag,
              'CoverImages': aigc_model.cover_images,
              'AigcType': aigc_model.aigc_type,
              'TagDescription': aigc_model.description,
              'VisionFoundation': aigc_model.base_model_type,
              'BaseModel': aigc_model.base_model_id or original_model_id,
              'WeightsName': aigc_model.weight_filename,
              'WeightsSha256': aigc_model.weight_sha256,
              'WeightsSize': aigc_model.weight_size,
              'ModelPath': aigc_model.model_path,
              'TriggerWords': aigc_model.trigger_words,
              'ModelSource': aigc_model.model_source,
              'SubVisionFoundation': aigc_model.base_model_sub_type,
          }
          payload.update(aigc_fields)
          if aigc_model.official_tags:
              payload['OfficialTags'] = aigc_model.official_tags
      else:
          create_url = f'{endpoint}/api/v1/models'

      request_headers = self.builder_headers(self.headers)
      # Endpoints whose host ends with the international domain's last
      # label get an explicit English language header.
      intl_suffix = '.' + DEFAULT_MODELSCOPE_INTL_DOMAIN.split('.')[-1]
      if endpoint.rstrip('/').endswith(intl_suffix):
          request_headers['X-Modelscope-Accept-Language'] = 'en_US'

      resp = self.session.post(
          create_url,
          json=payload,
          cookies=cookies,
          headers=request_headers)
      raise_for_http_status(resp)
      raise_on_error(resp.json())

      model_repo_url = f'{endpoint}/models/{model_id}'

      # AIGC models additionally push their files into the new repo.
      if is_aigc:
          aigc_model.upload_to_repo(self, model_id, token)

      return model_repo_url
  def create_model_tag(self,
                       model_id: str,
                       tag_name: str,
                       endpoint: Optional[str] = None,
                       token: Optional[str] = None,
                       aigc_model: Optional['AigcModel'] = None) -> str:
      """Create a tag for a model at ModelScope Hub.

      Args:
          model_id (str): The model id in format {owner}/{name}
          tag_name (str): The tag name (e.g., "v1.0.0"). ``main`` and
              ``master`` (case-insensitive) are rejected.
          endpoint: the endpoint to use, default to None to use endpoint specified in the class
          token (str, optional): access token for authentication
          aigc_model (AigcModel, optional): AigcModel instance for AIGC model tag creation.
              If provided, the AIGC tag endpoint is used and the request
              body is filled from the instance; otherwise the tag is
              created on the regular endpoint with ``Ref`` set to ``master``.
              Refer to modelscope.hub.utils.aigc.AigcModel for details.

      Returns:
          str: URL of the created tag

      Raises:
          InvalidParameter: If model_id or tag_name is None, or tag_name
              is a reserved name.
          ValueError: If no credentials are available.

      Note:
          model_id = {owner}/{name}
      """
      if model_id is None:
          raise InvalidParameter('model_id is required!')
      if tag_name is None:
          raise InvalidParameter('tag_name is required!')
      if tag_name.lower() in ('main', 'master'):
          raise InvalidParameter(
              f'tag_name "{tag_name}" is not allowed. '
              f'Please use a different tag name (e.g., "v1.0", "v1.1", "latest"). '
              f'Reserved names: main, master'
          )

      # Credentials are mandatory for tag creation.
      cookies = self.get_cookies(access_token=token, cookies_required=True)
      endpoint = endpoint or self.endpoint
      owner_or_group, name = model_id_to_group_owner_name(model_id)

      if aigc_model is None:
          # Regular model: tag the master branch.
          tag_api = f'{endpoint}/api/v1/models/{model_id}/repo/tag'
          payload = {
              'TagName': tag_name,
              'Ref': 'master'
          }
      else:
          # AIGC model: weights are pre-uploaded before the tag request.
          tag_api = f'{endpoint}/api/v1/models/aigc/repo/tag'
          aigc_model.preupload_weights(
              cookies=cookies,
              headers=self.builder_headers(self.headers),
              endpoint=endpoint)
          payload = {
              'CoverImages': aigc_model.cover_images,
              'Name': name,
              'Path': owner_or_group,
              'TagShowName': tag_name,
              'WeightsName': aigc_model.weight_filename,
              'WeightsSha256': aigc_model.weight_sha256,
              'WeightsSize': aigc_model.weight_size,
              'TriggerWords': aigc_model.trigger_words,
              'AigcType': aigc_model.aigc_type,
              'VisionFoundation': aigc_model.base_model_type
          }

      resp = self.session.post(
          tag_api,
          json=payload,
          cookies=cookies,
          headers=self.builder_headers(self.headers))
      raise_for_http_status(resp)
      raise_on_error(resp.json())

      return f'{endpoint}/models/{model_id}/tags/{tag_name}'
  def delete_model(self,
                   model_id: str,
                   endpoint: Optional[str] = None,
                   token: Optional[str] = None):
      """Delete model_id from ModelScope.

      Args:
          model_id (str): The model id.
          endpoint: the endpoint to use, default to None to use endpoint specified in the class
          token (str, optional): access token for authentication. When
              omitted, locally cached cookies are used, as before.

      Raises:
          ValueError: If not login.

      Note:
          model_id = {owner}/{name}
      """
      # Resolve credentials the same way the create_* methods do, so an
      # explicit token works here too; without a token this still falls
      # back to the cached cookies.
      cookies = self.get_cookies(access_token=token)
      if cookies is None:
          raise ValueError('Token does not exist, please login first.')

      if not endpoint:
          endpoint = self.endpoint
      path = f'{endpoint}/api/v1/models/{model_id}'

      r = self.session.delete(path,
                              cookies=cookies,
                              headers=self.builder_headers(self.headers))
      raise_for_http_status(r)
      raise_on_error(r.json())
  def get_model_url(self, model_id: str, endpoint: Optional[str] = None):
      """Return the ``.git`` URL of ``model_id`` under the hub API path.

      Args:
          model_id (str): The model id.
          endpoint: the endpoint to use, default to None to use endpoint specified in the class

      Returns:
          str: ``{endpoint}/api/v1/models/{model_id}.git``
      """
      base = endpoint or self.endpoint
      return f'{base}/api/v1/models/{model_id}.git'
  def get_model(
      self,
      model_id: str,
      revision: Optional[str] = DEFAULT_MODEL_REVISION,
      endpoint: Optional[str] = None
  ) -> dict:
      """Fetch the detail record of a model from ModelScope.

      Args:
          model_id (str): The model id.
          revision (str optional): revision of model; when truthy it is
              appended to the request as the ``Revision`` query parameter.
          endpoint: the endpoint to use, default to None to use endpoint specified in the class

      Returns:
          The model detail information (the data field of the response).

      Raises:
          NotExistError: If the server answers 200 but the response body
              is not ok; the server message is used as the error text.

      Note:
          model_id = {owner}/{name}
      """
      cookies = ModelScopeConfig.get_cookies()
      owner_or_group, name = model_id_to_group_owner_name(model_id)

      base = endpoint or self.endpoint
      url = f'{base}/api/v1/models/{owner_or_group}/{name}'
      if revision:
          url += f'?Revision={revision}'

      resp = self.session.get(url, cookies=cookies,
                              headers=self.builder_headers(self.headers))
      handle_http_response(resp, logger, cookies, model_id)

      if resp.status_code != HTTPStatus.OK:
          # Non-200: delegate to the shared HTTP status handler.
          raise_for_http_status(resp)
          return None

      payload = resp.json()
      if not is_ok(payload):
          raise NotExistError(payload[API_RESPONSE_FIELD_MESSAGE])
      return payload[API_RESPONSE_FIELD_DATA]
  def get_endpoint_for_read(self,
                            repo_id: str,
                            *,
                            repo_type: Optional[str] = None) -> str:
      """Pick the endpoint to use for read operations (download, list etc.).

      1. If the environment variable named by MODELSCOPE_DOMAIN is set to a
         non-blank value, the endpoint is built from that domain. The repo
         is checked on it with ``re_raise=True``; any exception is logged
         and re-raised, otherwise that endpoint is returned.
      2. Otherwise the repo is looked up on the cn-site and the ai-site
         (intl version). Checking order is determined by
         MODELSCOPE_PREFER_AI_SITE:
          a. if it is not set, the cn-site is checked before the ai-site
          b. otherwise the ai-site is checked before the cn-site
         The first endpoint on which the repo exists is returned. If the
         check on the alternative endpoint raises, the error is logged and
         re-raised.
      """
      domain_override = os.environ.get(MODELSCOPE_DOMAIN)
      if domain_override and domain_override.strip():
          endpoint = MODELSCOPE_URL_SCHEME + domain_override
          try:
              self.repo_exists(repo_id=repo_id, repo_type=repo_type, endpoint=endpoint, re_raise=True)
          except Exception:
              logger.error(f'Repo {repo_id} does not exist on {endpoint}.')
              raise
          return endpoint

      cn_first = not is_env_true(MODELSCOPE_PREFER_AI_SITE)
      preferred = get_endpoint(cn_site=cn_first)
      if self.repo_exists(repo_id, repo_type=repo_type, endpoint=preferred):
          return preferred

      # Not found on the preferred site: try the other one, and let its
      # failure propagate after logging.
      fallback = get_endpoint(cn_site=(not cn_first))
      logger.warning(f'Repo {repo_id} not exists on {preferred}, '
                     f'will try on alternative endpoint {fallback}.')
      try:
          self.repo_exists(
              repo_id, repo_type=repo_type, endpoint=fallback, re_raise=True)
      except Exception:
          logger.error(f'Repo {repo_id} not exists on either {preferred} or {fallback}')
          raise
      return fallback
  def model_info(self,
                 repo_id: str,
                 *,
                 revision: Optional[str] = DEFAULT_MODEL_REVISION,
                 endpoint: Optional[str] = None) -> ModelInfo:
      """Assemble a ``ModelInfo`` from model details, commits and file list.

      Args:
          repo_id (str): The model id in the format of
              ``namespace/model_name``.
          revision (str, optional): Specific revision of the model.
              Defaults to ``DEFAULT_MODEL_REVISION``.
          endpoint (str, optional): Hub endpoint to use. When ``None``,
              use the endpoint specified when initializing :class:`HubApi`.

      Returns:
          ModelInfo: Built from the fields returned by ``get_model``, plus
              the commit history, the file listing (recursive) and the
              owner/group part of ``repo_id`` as author.
      """
      author, _ = model_id_to_group_owner_name(repo_id)

      # The three lookups share the same revision and endpoint.
      details = self.get_model(
          model_id=repo_id, revision=revision, endpoint=endpoint)
      history = self.list_repo_commits(
          repo_id=repo_id,
          repo_type=REPO_TYPE_MODEL,
          revision=revision,
          endpoint=endpoint)
      files = self.get_model_files(
          model_id=repo_id,
          revision=revision,
          recursive=True,
          endpoint=endpoint)

      return ModelInfo(**details, commits=history, author=author, siblings=files)
  def dataset_info(self,
                   repo_id: str,
                   *,
                   revision: Optional[str] = None,
                   endpoint: Optional[str] = None) -> DatasetInfo:
      """Fetch a dataset's metadata together with its commit log and file list.

      Args:
          repo_id (str): Dataset id written as ``namespace/dataset_name``.
          revision (str, optional): Revision forwarded to ``get_dataset``
              and ``list_repo_commits`` as given. For the file listing a
              falsy value is replaced by ``DEFAULT_DATASET_REVISION``.
              Defaults to ``None``.
          endpoint (str, optional): Hub endpoint forwarded to every hub
              query issued here. Defaults to ``None``.

      Returns:
          DatasetInfo: Built from the ``get_dataset`` payload, extended with
          ``commits``, ``author`` and ``siblings``.
      """
      author, _ = model_id_to_group_owner_name(repo_id)
      payload = self.get_dataset(
          dataset_id=repo_id, revision=revision, endpoint=endpoint)
      history = self.list_repo_commits(
          repo_id=repo_id,
          repo_type=REPO_TYPE_DATASET,
          revision=revision,
          endpoint=endpoint)
      # Only the file listing falls back to the default dataset revision.
      listing_revision = revision or DEFAULT_DATASET_REVISION
      file_entries = self.get_dataset_files(
          repo_id=repo_id,
          revision=listing_revision,
          recursive=True,
          endpoint=endpoint)
      return DatasetInfo(
          **payload, commits=history, author=author, siblings=file_entries)
  def repo_info(
          self,
          repo_id: str,
          *,
          repo_type: Optional[str] = REPO_TYPE_MODEL,
          revision: Optional[str] = DEFAULT_MODEL_REVISION,
          endpoint: Optional[str] = None
  ) -> Union[ModelInfo, DatasetInfo]:
      """Return repository information for a model or a dataset.

      Args:
          repo_id (str): Repository id written as ``namespace/repo_name``.
          repo_type (str, optional): ``"model"`` or ``"dataset"``. ``None``
              is treated as ``"model"``.
          revision (str, optional): Revision passed on to ``model_info`` or
              ``dataset_info``. Defaults to ``DEFAULT_MODEL_REVISION``.
          endpoint (str, optional): Hub endpoint passed on to the delegated
              call. Defaults to ``None``.

      Returns:
          Union[ModelInfo, DatasetInfo]: Result of ``model_info`` for model
          repositories, of ``dataset_info`` for dataset repositories.

      Raises:
          InvalidParameter: If ``repo_type`` is neither model nor dataset.
      """
      effective_type = REPO_TYPE_MODEL if repo_type is None else repo_type
      if effective_type == REPO_TYPE_MODEL:
          fetch = self.model_info
      elif effective_type == REPO_TYPE_DATASET:
          fetch = self.dataset_info
      else:
          raise InvalidParameter(
              f'Arg repo_type {repo_type} not supported. Please choose from {REPO_TYPE_SUPPORT}.')
      return fetch(repo_id=repo_id, revision=revision, endpoint=endpoint)
  def repo_exists(
          self,
          repo_id: str,
          *,
          repo_type: Optional[str] = None,
          endpoint: Optional[str] = None,
          re_raise: Optional[bool] = False,
          token: Optional[str] = None
  ) -> bool:
      """Check whether a repository exists on ModelScope.

      Args:
          repo_id (`str`):
              A namespace (user or an organization) and a repo name separated
              by a `/`.
          repo_type (`str`, *optional*):
              `None` or `"model"` to look the id up as a model, `"dataset"`
              to look it up as a dataset. Default is `None`.
              TODO: support studio
          endpoint (`str`, *optional*):
              Endpoint to query. When `None`, `self.endpoint` is used.
          re_raise (`bool`, *optional*):
              When true, a 404 answer raises `HTTPError` instead of
              returning `False`.
          token (`str`, *optional*):
              Access token used to obtain the cookies sent with the request.

      Returns:
          `True` if the hub answers 200, `False` if it answers 404 and
          `re_raise` is false.

      Raises:
          HTTPError: On a 404 answer when `re_raise` is true.
          Exception: If `repo_type` is unsupported, `repo_id` is not of the
              form `namespace/name`, or the hub answers with any status
              other than 200 or 404.
      """
      if endpoint is None:
          endpoint = self.endpoint
      if (repo_type is not None) and repo_type.lower() not in REPO_TYPE_SUPPORT:
          raise Exception('Not support repo-type: %s' % repo_type)
      if (repo_id is None) or repo_id.count('/') != 1:
          # Report the offending id itself (previously repo_type was printed).
          raise Exception('Invalid repo_id: %s, must be of format namespace/name' % repo_id)

      cookies = self.get_cookies(access_token=token, cookies_required=False)
      owner_or_group, name = model_id_to_group_owner_name(repo_id)
      if (repo_type is not None) and repo_type.lower() == REPO_TYPE_DATASET:
          path = f'{endpoint}/api/v1/datasets/{owner_or_group}/{name}'
      else:
          path = f'{endpoint}/api/v1/models/{owner_or_group}/{name}'

      r = self.session.get(path, cookies=cookies,
                           headers=self.builder_headers(self.headers))
      code = handle_http_response(r, logger, cookies, repo_id, False)
      if code == 200:
          return True
      if code == 404:
          if re_raise:
              raise HTTPError(r)
          return False

      logger.warning(f'Check repo_exists return status code {code}.')
      # Name the repository in the error (previously repo_type was printed).
      raise Exception(
          'Failed to check existence of repo: %s, make sure you have access authorization.'
          % repo_id)
  def delete_repo(self, repo_id: str, repo_type: str, endpoint: Optional[str] = None):
      """Remove a model or dataset repository from ModelScope.

      Args:
          repo_id (`str`):
              A namespace (user or an organization) and a repo name separated
              by a `/`.
          repo_type (`str`):
              Kind of repository to delete, either `model` or `dataset`.
          endpoint (`str`, *optional*):
              Endpoint to use. A falsy value selects `self.endpoint`.

      Raises:
          Exception: If `repo_type` is neither `model` nor `dataset`.
      """
      target_endpoint = endpoint or self.endpoint
      if repo_type == REPO_TYPE_DATASET:
          remove = self.delete_dataset
      elif repo_type == REPO_TYPE_MODEL:
          remove = self.delete_model
      else:
          raise Exception(f'Arg repo_type {repo_type} not supported.')
      remove(repo_id, target_endpoint)
      logger.info(f'Repo {repo_id} deleted successfully.')
  @staticmethod
  def _create_default_config(model_dir):
      """Write a minimal configuration file into ``model_dir``.

      The file is named ``ModelFile.CONFIGURATION`` and holds two entries:
      the framework (``Frameworks.torch``) and the task (``Tasks.other``).
      """
      target_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
      default_cfg = dict([
          (ConfigFields.framework, Frameworks.torch),
          (ConfigFields.task, Tasks.other),
      ])
      with open(target_path, 'w') as handle:
          json.dump(default_cfg, handle)
  def push_model(self,
                 model_id: str,
                 model_dir: str,
                 visibility: Optional[int] = ModelVisibility.PUBLIC,
                 license: Optional[str] = Licenses.APACHE_V2,
                 chinese_name: Optional[str] = None,
                 commit_message: Optional[str] = 'upload model',
                 tag: Optional[str] = None,
                 revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
                 original_model_id: Optional[str] = None,
                 ignore_file_pattern: Optional[Union[List[str], str]] = None,
                 lfs_suffix: Optional[Union[str, List[str]]] = None):
      """Upload model from a given directory to given repository.

      Deprecated: use git commands directly or ``HubApi().upload_folder``.

      This function uploads the files in the given directory to the given
      repository. If ``model_dir`` has no configuration file, a default one
      is created in it first. If the repository does not exist remotely, it
      is created with the given visibility, license and chinese_name. If the
      revision does not exist in the remote repository, a new branch is
      created for it.

      This function must be called after logging in (cookies must be
      available locally), otherwise ``NotLoginException`` is raised.

      If any error occurs, please upload via git commands.

      Args:
          model_id (str):
              The model id to be uploaded, caller must have write permission for it.
          model_dir (str):
              The directory holding the files to upload.
          visibility (int, optional):
              Visibility of the newly created model (1-private, 5-public).
              Only used when the model has to be created, but must not be
              ``None``. Defaults to ``ModelVisibility.PUBLIC``.
          license (`str`, *optional*):
              License of the newly created model (see License). Only used
              when the model has to be created, but must not be ``None``.
              Defaults to ``Licenses.APACHE_V2``.
          chinese_name (`str`, *optional*, defaults to `None`):
              Chinese name of the newly created model.
          commit_message (`str`, *optional*, defaults to `'upload model'`):
              Commit message of the push. When empty, a timestamped message
              is generated.
          tag (`str`, *optional*, defaults to `None`):
              The tag to put on this commit.
          revision (`str`, *optional*, defaults to DEFAULT_REPOSITORY_REVISION):
              Which branch to push. If the branch does not exist, it is
              created and pushed to.
          original_model_id (str, optional): The base model id which this model is trained from.
          ignore_file_pattern (`Union[List[str], str]`, optional):
              Regular expression(s); top-level entries of ``model_dir``
              whose name matches any of them are not uploaded.
          lfs_suffix (`Union[str, List[str]]`, optional):
              File types to manage with LFS. examples: '*.safetensors'.

      Raises:
          InvalidParameter: Parameter invalid.
          NotLoginException: Not logged in.
      """
      warnings.warn(
          'This function is deprecated and will be removed in future versions. '
          'Please use git command directly or use HubApi().upload_folder instead',
          DeprecationWarning,
          stacklevel=2
      )
      if model_id is None:
          raise InvalidParameter('model_id cannot be empty!')
      if model_dir is None:
          raise InvalidParameter('model_dir cannot be empty!')
      if not os.path.exists(model_dir) or os.path.isfile(model_dir):
          raise InvalidParameter('model_dir must be a valid directory.')
      cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
      if not os.path.exists(cfg_file):
          logger.warning(
              f'No {ModelFile.CONFIGURATION} file found in {model_dir}, creating a default one.')
          HubApi._create_default_config(model_dir)

      cookies = ModelScopeConfig.get_cookies()
      if cookies is None:
          raise NotLoginException('Must login before upload!')
      # Snapshot the directory before the temporary clone folder is created
      # inside it, so the clone itself is never part of the upload.
      files_to_save = os.listdir(model_dir)
      folder_size = get_readable_folder_size(model_dir)
      if ignore_file_pattern is None:
          ignore_file_pattern = []
      if isinstance(ignore_file_pattern, str):
          ignore_file_pattern = [ignore_file_pattern]

      if visibility is None or license is None:
          raise InvalidParameter('Visibility and License cannot be empty for new model.')
      if not self.repo_exists(model_id):
          logger.info('Creating new model [%s]' % model_id)
          self.create_model(
              model_id=model_id,
              visibility=visibility,
              license=license,
              chinese_name=chinese_name,
              original_model_id=original_model_id)

      tmp_dir = os.path.join(model_dir, TEMPORARY_FOLDER_NAME)  # make temporary folder
      git_wrapper = GitCommandWrapper()
      logger.info(f'Pushing folder {model_dir} as model {model_id}.')
      logger.info(f'Total folder size {folder_size}, this may take a while depending on actual pushing size...')
      try:
          repo = Repository(model_dir=tmp_dir, clone_from=model_id)
          branches = git_wrapper.get_remote_branches(tmp_dir)
          if revision not in branches:
              logger.info('Creating new branch %s' % revision)
              git_wrapper.new_branch(tmp_dir, revision)
          git_wrapper.checkout(tmp_dir, revision)

          # Empty the working tree (dot-entries are kept) so that the push
          # mirrors the local folder.
          for f in os.listdir(tmp_dir):
              if f.startswith('.'):
                  continue
              src = os.path.join(tmp_dir, f)
              if os.path.isfile(src):
                  os.remove(src)
              else:
                  shutil.rmtree(src, ignore_errors=True)

          # Copy every non-hidden, non-ignored top-level entry into the clone.
          for f in files_to_save:
              if f.startswith('.'):
                  continue
              if any(re.search(pattern, f) is not None for pattern in ignore_file_pattern):
                  continue
              src = os.path.join(model_dir, f)
              if os.path.isdir(src):
                  shutil.copytree(src, os.path.join(tmp_dir, f))
              else:
                  shutil.copy(src, tmp_dir)

          if not commit_message:
              date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
              commit_message = '[automsg] push model %s to hub at %s' % (
                  model_id, date)
          if lfs_suffix is not None:
              lfs_suffix_list = [lfs_suffix] if isinstance(lfs_suffix, str) else lfs_suffix
              for suffix in lfs_suffix_list:
                  repo.add_lfs_type(suffix)
          repo.push(
              commit_message=commit_message,
              local_branch=revision,
              remote_branch=revision)
          if tag is not None:
              repo.tag_and_push(tag, tag)
          logger.info(f'Successfully push folder {model_dir} to remote repo [{model_id}].')
      finally:
          # Errors propagate unchanged; the temporary clone is always removed.
          shutil.rmtree(tmp_dir, ignore_errors=True)
  def list_models(self,
                  owner_or_group: str,
                  page_number: Optional[int] = 1,
                  page_size: Optional[int] = 10,
                  endpoint: Optional[str] = None) -> dict:
      """List models in owner or group.

      Args:
          owner_or_group (str): owner or group.
          page_number (int, optional): The page number, default: 1
          page_size (int, optional): The page size, default: 10
          endpoint: the endpoint to use, default to None to use endpoint specified in the class

      Raises:
          RequestError: The request error.

      Returns:
          dict: {"models": "list of models", "TotalCount": total_number_of_models_in_owner_or_group}
      """
      cookies = ModelScopeConfig.get_cookies()
      if not endpoint:
          endpoint = self.endpoint
      path = f'{endpoint}/api/v1/models/'
      # json.dumps quotes and escapes the owner name, so quotes or backslashes
      # in it can no longer corrupt the request body.
      body = '{"Path":%s, "PageNumber":%s, "PageSize": %s}' % (
          json.dumps(owner_or_group), page_number, page_size)
      r = self.session.put(
          path,
          data=body,
          cookies=cookies,
          headers=self.builder_headers(self.headers))
      handle_http_response(r, logger, cookies, owner_or_group)
      if r.status_code == HTTPStatus.OK:
          payload = r.json()  # parse the response once
          if is_ok(payload):
              return payload[API_RESPONSE_FIELD_DATA]
          raise RequestError(payload[API_RESPONSE_FIELD_MESSAGE])
      raise_for_http_status(r)
      return None
  def list_datasets(self,
                    owner_or_group: str,
                    *,
                    page_number: Optional[int] = 1,
                    page_size: Optional[int] = 10,
                    sort: Optional[str] = None,
                    search: Optional[str] = None,
                    endpoint: Optional[str] = None,
                    ) -> dict:
      """Query the OpenAPI dataset listing with paging, sorting and filtering.

      Args:
          owner_or_group (str): Sent as the ``author`` filter when non-empty.
          page_number (int, optional): The page number. Defaults to 1.
          page_size (int, optional): The page size. Defaults to 10.
          sort (str, optional): Sort key; must be one of ``VALID_SORT_KEYS``
              when given. Omitted from the request when empty.
          search (str, optional): Search keyword. Omitted from the request
              when empty.
          endpoint (str, optional): Hub endpoint to use. A falsy value
              selects ``self.endpoint``.

      Returns:
          dict: The ``data`` member of the OpenAPI response, e.g.
              {
                  "datasets": [...],
                  "total_count": int,
                  "page_number": int,
                  "page_size": int
              }

      Raises:
          InvalidParameter: If ``sort`` is not a supported key.
          RequestError: If the response is not marked successful or has no
              ``data`` member.
      """
      base_url = endpoint or self.endpoint
      url = f'{base_url}/openapi/v1/datasets'

      # Optional filters are only sent when they carry a value.
      query = {'page_number': page_number, 'page_size': page_size}
      if sort:
          if sort not in VALID_SORT_KEYS:
              raise InvalidParameter(
                  f'Invalid sort key: {sort}. Supported sort keys: {list(VALID_SORT_KEYS)}')
          query['sort'] = sort
      if search:
          query['search'] = search
      if owner_or_group:
          query['author'] = owner_or_group

      jar = ModelScopeConfig.get_cookies()
      request_headers = self.builder_headers(self.headers)
      response = self.session.get(
          url, params=query, cookies=jar, headers=request_headers)
      raise_for_http_status(response)

      body = response.json()
      succeeded = body.get('success') is True and 'data' in body
      if not succeeded:
          # Unexpected schema: surface the server message when there is one.
          raise RequestError(body.get('message') or 'Failed to list datasets')
      return body['data']
  875. def _check_cookie(self, use_cookies: Union[bool, CookieJar] = False) -> CookieJar: # noqa
  876. cookies = None
  877. if isinstance(use_cookies, CookieJar):
  878. cookies = use_cookies
  879. elif isinstance(use_cookies, bool):
  880. cookies = ModelScopeConfig.get_cookies()
  881. if use_cookies and cookies is None:
  882. raise ValueError('Token does not exist, please login first.')
  883. return cookies
  def list_model_revisions(
          self,
          model_id: str,
          cutoff_timestamp: Optional[int] = None,
          use_cookies: Union[bool, CookieJar] = False) -> List[str]:
      """Return the revision names reported by ``list_model_revisions_detail``.

      Args:
          model_id (str): The model id
          cutoff_timestamp (int): Forwarded unchanged; tags created before the
              cutoff will be included. The timestamp is represented by the
              seconds elapsed from the epoch time.
          use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True,
              will load cookie from local. Defaults to False.

      Returns:
          List[str]: The ``'Revision'`` value of every detail entry, or an
          empty list when no details are returned.
      """
      details = self.list_model_revisions_detail(
          model_id=model_id,
          cutoff_timestamp=cutoff_timestamp,
          use_cookies=use_cookies)
      if not details:
          return []
      return [entry['Revision'] for entry in details]
  def list_model_revisions_detail(
          self,
          model_id: str,
          cutoff_timestamp: Optional[int] = None,
          use_cookies: Union[bool, CookieJar] = False,
          endpoint: Optional[str] = None) -> List[str]:
      """Get the tag details of a model up to a cutoff time.

      Args:
          model_id (str): The model id
          cutoff_timestamp (int): Tags created before the cutoff will be included.
              The timestamp is represented by the seconds elapsed from the epoch time.
              When ``None``, the value of ``get_release_datetime()`` is used.
          use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True,
              will load cookie from local. Defaults to False.
          endpoint: the endpoint to use, default to None to use endpoint specified in the class

      Returns:
          The ``Tags`` entry of the response's ``RevisionMap``.
      """
      cookies = self._check_cookie(use_cookies)
      if cutoff_timestamp is None:
          cutoff_timestamp = get_release_datetime()
      if not endpoint:
          endpoint = self.endpoint
      # One f-string only: applying '%' to an already formatted string would
      # misinterpret any literal '%' contained in endpoint or model_id.
      path = f'{endpoint}/api/v1/models/{model_id}/revisions?EndTime={cutoff_timestamp}'
      r = self.session.get(path, cookies=cookies,
                           headers=self.builder_headers(self.headers))
      handle_http_response(r, logger, cookies, model_id)
      d = r.json()
      raise_on_error(d)
      info = d[API_RESPONSE_FIELD_DATA]
      # tags returned from backend are guaranteed to be ordered by create-time
      return info['RevisionMap']['Tags']
  def get_branch_tag_detail(self, details, name):
      """Return the first entry of ``details`` whose ``'Revision'`` equals ``name``.

      Returns ``None`` when no entry matches.
      """
      matches = (entry for entry in details if entry['Revision'] == name)
      return next(matches, None)
  def get_valid_revision_detail(self,
                                model_id: str,
                                revision=None,
                                cookies: Optional[CookieJar] = None,
                                endpoint: Optional[str] = None):
      """Resolve ``revision`` to the matching branch or tag detail entry.

      Args:
          model_id (str): The model id.
          revision (str, optional): Requested branch or tag name. When
              ``None`` a default is chosen (see below).
          cookies (CookieJar, optional): Cookies used for the hub query.
          endpoint (str, optional): Endpoint to use. A falsy value selects
              ``self.endpoint``.

      Returns:
          The detail entry of the selected revision as found in the hub's
          branch/tag lists. May be ``None`` when the selected name has no
          entry in the list it is looked up in.

      Raises:
          NotExistError: If the requested revision is not available.
      """
      if not endpoint:
          endpoint = self.endpoint
      release_timestamp = get_release_datetime()
      current_timestamp = int(round(datetime.datetime.now().timestamp()))
      # for active development in library codes (non-release-branches), release_timestamp
      # is set to be a far-away-time-in-the-future, to ensure that we shall
      # get the master-HEAD version from model repo by default (when no revision is provided)
      all_branches_detail, all_tags_detail = self.get_model_branches_and_tags_details(
          model_id, use_cookies=False if cookies is None else cookies, endpoint=endpoint)
      # Treat a missing (None) list as empty once, so that the len() call
      # and the lookups below never fail on None.
      all_branches_detail = all_branches_detail or []
      all_tags_detail = all_tags_detail or []
      all_branches = [x['Revision'] for x in all_branches_detail]
      all_tags = [x['Revision'] for x in all_tags_detail]

      if release_timestamp > current_timestamp + ONE_YEAR_SECONDS:
          if revision is None:
              revision = MASTER_MODEL_BRANCH
              logger.info(
                  'Model revision not specified, using default [%s] version.'
                  % revision)
          if revision not in all_branches and revision not in all_tags:
              raise NotExistError('The model: %s has no revision : %s .' % (model_id, revision))

          # A tag takes precedence over a branch of the same name.
          revision_detail = self.get_branch_tag_detail(all_tags_detail, revision)
          if revision_detail is None:
              revision_detail = self.get_branch_tag_detail(all_branches_detail, revision)
          logger.debug('Development mode use revision: %s' % revision)
      else:
          if revision is not None and revision in all_branches:
              revision_detail = self.get_branch_tag_detail(all_branches_detail, revision)
              return revision_detail

          if len(all_tags_detail) == 0:  # use no revision use master as default.
              if revision is None or revision == MASTER_MODEL_BRANCH:
                  revision = MASTER_MODEL_BRANCH
              else:
                  raise NotExistError('The model: %s has no revision: %s !' % (model_id, revision))
              revision_detail = self.get_branch_tag_detail(all_branches_detail, revision)
          else:
              if revision is None:  # user not specified revision, use latest revision before release time
                  revisions_detail = [x for x in all_tags_detail
                                      if x['CreatedAt'] <= release_timestamp]
                  if len(revisions_detail) > 0:
                      revision = revisions_detail[0]['Revision']  # use latest revision before release time.
                      revision_detail = revisions_detail[0]
                  else:
                      revision = MASTER_MODEL_BRANCH
                      revision_detail = self.get_branch_tag_detail(all_branches_detail, revision)
                  vl = '[%s]' % ','.join(all_tags)
                  logger.warning('Model revision should be specified from revisions: %s' % (vl))
                  logger.warning('Model revision not specified, use revision: %s' % revision)
              else:
                  # use user-specified revision
                  if revision not in all_tags:
                      if revision == MASTER_MODEL_BRANCH:
                          logger.warning('Using the master branch is fragile, please use it with caution!')
                          revision_detail = self.get_branch_tag_detail(all_branches_detail, revision)
                      else:
                          vl = '[%s]' % ','.join(all_tags)
                          raise NotExistError('The model: %s has no revision: %s valid are: %s!' %
                                              (model_id, revision, vl))
                  else:
                      revision_detail = self.get_branch_tag_detail(all_tags_detail, revision)
                  logger.info('Use user-specified model revision: %s' % revision)
      return revision_detail
  def get_valid_revision(self,
                         model_id: str,
                         revision=None,
                         cookies: Optional[CookieJar] = None,
                         endpoint: Optional[str] = None):
      """Return the ``'Revision'`` name of the entry picked by ``get_valid_revision_detail``."""
      detail = self.get_valid_revision_detail(
          model_id=model_id,
          revision=revision,
          cookies=cookies,
          endpoint=endpoint)
      return detail['Revision']
  def get_model_branches_and_tags_details(
          self,
          model_id: str,
          use_cookies: Union[bool, CookieJar] = False,
          endpoint: Optional[str] = None
  ) -> Tuple[List[str], List[str]]:
      """Query the hub for the branch and tag entries of a model.

      Args:
          model_id (str): The model id
          use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True,
              will load cookie from local. Defaults to False.
          endpoint: the endpoint to use, default to None to use endpoint specified in the class

      Returns:
          A pair ``(branches, tags)`` taken from the ``Branches`` and ``Tags``
          entries of the response's ``RevisionMap``.
      """
      jar = self._check_cookie(use_cookies)
      base_url = endpoint or self.endpoint
      url = f'{base_url}/api/v1/models/{model_id}/revisions'
      response = self.session.get(
          url, cookies=jar, headers=self.builder_headers(self.headers))
      handle_http_response(response, logger, jar, model_id)
      body = response.json()
      raise_on_error(body)
      revision_map = body[API_RESPONSE_FIELD_DATA]['RevisionMap']
      return revision_map['Branches'], revision_map['Tags']
  def get_model_branches_and_tags(
          self,
          model_id: str,
          use_cookies: Union[bool, CookieJar] = False,
  ) -> Tuple[List[str], List[str]]:
      """Get the branch names and tag names of a model.

      Args:
          model_id (str): The model id
          use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True,
              will load cookie from local. Defaults to False.

      Returns:
          Tuple[List[str], List[str]]: Branch names and tag names, each taken
          from the ``'Revision'`` field of the corresponding detail entries.
      """
      def _names(entries):
          # A missing or empty detail list maps to an empty name list.
          if not entries:
              return []
          return [entry['Revision'] for entry in entries]

      branch_entries, tag_entries = self.get_model_branches_and_tags_details(
          model_id=model_id, use_cookies=use_cookies)
      return _names(branch_entries), _names(tag_entries)
  def get_model_files(self,
                      model_id: str,
                      revision: Optional[str] = DEFAULT_MODEL_REVISION,
                      root: Optional[str] = None,
                      recursive: Optional[bool] = False,
                      use_cookies: Union[bool, CookieJar] = False,
                      headers: Optional[dict] = {},
                      endpoint: Optional[str] = None) -> List[dict]:
      """List the models files.

      Args:
          model_id (str): The model id
          revision (Optional[str], optional): The branch or tag name. When
              empty, no ``Revision`` parameter is sent.
          root (Optional[str], optional): The root path. Defaults to None.
          recursive (Optional[bool], optional): Is recursive list files. Defaults to False.
          use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True,
              will load cookie from local. Defaults to False.
          headers: request headers. ``None`` selects ``self.headers``. The
              mapping passed in is not modified.
          endpoint: the endpoint to use, default to None to use endpoint specified in the class

      Returns:
          List[dict]: Model file list, without ``.gitignore`` and
          ``.gitattributes`` entries.
      """
      if not endpoint:
          endpoint = self.endpoint
      if revision:
          path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s' % (
              endpoint, model_id, revision, recursive)
      else:
          path = '%s/api/v1/models/%s/repo/files?Recursive=%s' % (
              endpoint, model_id, recursive)
      cookies = self._check_cookie(use_cookies)
      if root is not None:
          path = path + f'&Root={root}'

      # Work on a copy: the request id must not leak into the caller's dict,
      # the shared default argument, or self.headers.
      request_headers = dict(self.headers if headers is None else headers)
      request_headers['X-Request-ID'] = str(uuid.uuid4().hex)

      r = self.session.get(
          path, cookies=cookies, headers=request_headers)

      handle_http_response(r, logger, cookies, model_id)
      d = r.json()
      raise_on_error(d)

      remote_files = d[API_RESPONSE_FIELD_DATA]['Files']
      if not remote_files:
          logger.warning(f'No files found in model {model_id} at revision {revision}.')
          return []
      hidden_names = ('.gitignore', '.gitattributes')
      return [file for file in remote_files if file['Name'] not in hidden_names]
  def file_exists(
          self,
          repo_id: str,
          filename: str,
          *,
          revision: Optional[str] = None,
  ):
      """Tell whether a file is part of a model repository.

      Args:
          repo_id (`str`): The repo id to use
          filename (`str`): The queried filename, if the file exists in a sub folder,
              please pass <sub-folder-name>/<file-name>
          revision (`Optional[str]`): The repo revision

      Returns:
          `True` when `filename` equals the `'Path'` of an entry in the
          recursive file listing, `False` otherwise.
      """
      stored_cookies = ModelScopeConfig.get_cookies()
      # Locally stored cookies are used when present, anonymous access otherwise.
      cookie_arg = stored_cookies if stored_cookies is not None else False
      listing = self.get_model_files(
          repo_id,
          recursive=True,
          revision=revision,
          use_cookies=cookie_arg,
      )
      known_paths = [entry['Path'] for entry in listing]
      return filename in known_paths
  def create_dataset(self,
                     dataset_name: str,
                     namespace: str,
                     chinese_name: Optional[str] = '',
                     license: Optional[str] = Licenses.APACHE_V2,
                     visibility: Optional[int] = DatasetVisibility.PUBLIC,
                     description: Optional[str] = '',
                     endpoint: Optional[str] = None, ) -> str:
      """Create a dataset repository on the hub and return its URL.

      Args:
          dataset_name (str): Name of the dataset to create.
          namespace (str): Owner of the new dataset.
          chinese_name (str, optional): Chinese display name.
          license (str, optional): License of the dataset.
          visibility (int, optional): Visibility of the dataset.
          description (str, optional): Description of the dataset.
          endpoint (str, optional): Endpoint to use. A falsy value selects
              ``self.endpoint``.

      Returns:
          str: ``{endpoint}/datasets/{namespace}/{dataset_name}``.

      Raises:
          InvalidParameter: If ``dataset_name`` or ``namespace`` is ``None``.
          ValueError: If no cookies are stored locally.
      """
      if dataset_name is None or namespace is None:
          raise InvalidParameter('dataset_name and namespace are required!')

      jar = ModelScopeConfig.get_cookies()
      if jar is None:
          raise ValueError('Token does not exist, please login first.')

      base_url = endpoint or self.endpoint
      url = f'{base_url}/api/v1/datasets'
      # Each field is sent as a form part without a file name.
      field_values = (
          ('Name', dataset_name),
          ('ChineseName', chinese_name),
          ('Owner', namespace),
          ('License', license),
          ('Visibility', visibility),
          ('Description', description),
      )
      form_fields = {key: (None, value) for key, value in field_values}

      response = self.session.post(
          url,
          files=form_fields,
          cookies=jar,
          headers=self.builder_headers(self.headers),
      )
      handle_http_post_error(response, url, form_fields)
      raise_on_error(response.json())

      dataset_repo_url = f'{base_url}/datasets/{namespace}/{dataset_name}'
      logger.info(f'Create dataset success: {dataset_repo_url}')
      return dataset_repo_url
  def delete_dataset(self, dataset_id: str, endpoint: Optional[str] = None):
      """Delete a dataset repository from the hub.

      Args:
          dataset_id (str): Id of the dataset, used as the last part of the
              request path.
          endpoint (str, optional): Endpoint to use. A falsy value selects
              ``self.endpoint``.

      Raises:
          ValueError: If no cookies are stored locally.
      """
      jar = ModelScopeConfig.get_cookies()
      base_url = endpoint or self.endpoint
      if jar is None:
          raise ValueError('Token does not exist, please login first.')

      url = f'{base_url}/api/v1/datasets/{dataset_id}'
      response = self.session.delete(
          url, cookies=jar, headers=self.builder_headers(self.headers))
      raise_for_http_status(response)
      raise_on_error(response.json())
  def get_dataset_id_and_type(self, dataset_name: str, namespace: str, endpoint: Optional[str] = None):
      """Look a dataset up on the hub and return its ``(Id, Type)`` pair.

      Both values are read from the ``Data`` member of the response.
      """
      base_url = endpoint or self.endpoint
      url = f'{base_url}/api/v1/datasets/{namespace}/{dataset_name}'
      jar = ModelScopeConfig.get_cookies()
      response = self.session.get(url, cookies=jar)
      body = response.json()
      datahub_raise_on_error(url, body, response)
      record = body['Data']
      return record['Id'], record['Type']
  def list_repo_tree(self,
                     dataset_name: str,
                     namespace: str,
                     revision: str,
                     root_path: str,
                     recursive: bool = True,
                     page_number: int = 1,
                     page_size: int = 100,
                     endpoint: Optional[str] = None):
      """
      @deprecated: Use `get_dataset_files` instead.

      Resolves the dataset's hub id, then requests one page of its repo tree
      and returns the decoded response.
      """
      warnings.warn('The function `list_repo_tree` is deprecated, use `get_dataset_files` instead.',
                    DeprecationWarning)

      hub_id, _ = self.get_dataset_id_and_type(
          dataset_name=dataset_name, namespace=namespace, endpoint=endpoint)

      # The flag is sent as the text 'True' / 'False'.
      recursive_flag = 'True' if recursive else 'False'
      base_url = endpoint or self.endpoint
      url = f'{base_url}/api/v1/datasets/{hub_id}/repo/tree'
      query = {
          'Revision': revision or 'master',
          'Root': root_path or '/',
          'Recursive': recursive_flag,
          'PageNumber': page_number,
          'PageSize': page_size,
      }

      jar = ModelScopeConfig.get_cookies()
      response = self.session.get(url, params=query, cookies=jar)
      body = response.json()
      datahub_raise_on_error(url, body, response)
      return body
  def list_repo_commits(
          self,
          repo_id: str,
          *,
          repo_type: Optional[str] = REPO_TYPE_MODEL,
          revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
          page_number: int = 1,
          page_size: int = 50,
          endpoint: Optional[str] = None):
      """
      Get the commit history for a repository.

      Args:
          repo_id (str): The repository id, in the format of `namespace/repo_name`.
          repo_type (Optional[str]): The type of the repository. Supported types are `model` and `dataset`.
              A falsy value selects the models API.
          revision (str): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`.
          page_number (int): The page number for pagination. Defaults to 1.
          page_size (int): The number of commits per page. Defaults to 50.
          endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.

      Returns:
          CommitHistoryResponse: The commit history response, or None when the
              `Code` field of the response is not `HTTPStatus.OK`.

      Raises:
          ValueError: If `repo_id` is not of the form `namespace/repo_name`.
          Exception: If the request fails with a `requests` exception; the original
              exception is attached as `__cause__`.

      Examples:
          >>> from modelscope.hub.api import HubApi
          >>> api = HubApi()
          >>> commit_history = api.list_repo_commits('meituan/Meeseeks')
          >>> print(f"Total commits: {commit_history.total_count}")
          >>> for commit in commit_history.commits:
          ...     print(f"{commit.short_id}: {commit.title}")
      """
      from datasets.utils.file_utils import is_relative_path

      if not (is_relative_path(repo_id) and repo_id.count('/') == 1):
          raise ValueError(f'Invalid repo_id: {repo_id} !')

      if not endpoint:
          endpoint = self.endpoint

      if repo_type:
          commits_url = f'{endpoint}/api/v1/{repo_type}s/{repo_id}/commits'
      else:
          commits_url = f'{endpoint}/api/v1/models/{repo_id}/commits'

      params = {
          'Ref': revision or DEFAULT_REPOSITORY_REVISION,
          'PageNumber': page_number,
          'PageSize': page_size
      }
      cookies = ModelScopeConfig.get_cookies()

      try:
          r = self.session.get(commits_url, params=params,
                               cookies=cookies, headers=self.builder_headers(self.headers))
          raise_for_http_status(r)
          resp = r.json()
          raise_on_error(resp)
          if resp.get('Code') == HTTPStatus.OK:
              return CommitHistoryResponse.from_api_response(resp)
      except requests.exceptions.RequestException as e:
          # Chain the cause so the underlying network error stays visible in tracebacks.
          raise Exception(f'Failed to get repository commits for {repo_id}: {str(e)}') from e
      return None
  def get_dataset_files(
          self,
          repo_id: str,
          *,
          revision: str = DEFAULT_REPOSITORY_REVISION,
          root_path: str = '/',
          recursive: bool = True,
          page_number: int = 1,
          page_size: int = 100,
          endpoint: Optional[str] = None):
      """
      List the files of a dataset repository.

      Args:
          repo_id (str): The repository id, in the format of `namespace/dataset_name`.
          revision (str): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`.
          root_path (str): The root path to list. Defaults to '/'.
          recursive (bool): Whether to list recursively. Defaults to True.
          page_number (int): The page number for pagination. Defaults to 1.
          page_size (int): The number of items per page. Defaults to 100.
          endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.

      Returns:
          List: The `Data.Files` entry of the repo-tree response.
              e.g. [{'CommitId': None, 'CommitMessage': '...', 'Size': 0, 'Type': 'tree'}, ...]

      Raises:
          ValueError: If `repo_id` is not of the form `namespace/dataset_name`.
      """
      from datasets.utils.file_utils import is_relative_path

      if not is_relative_path(repo_id) or repo_id.count('/') != 1:
          raise ValueError(f'Invalid repo_id: {repo_id} !')
      owner, name = repo_id.split('/')

      # Only the hub id is needed here; the dataset type is ignored.
      hub_id, _ = self.get_dataset_id_and_type(
          dataset_name=name, namespace=owner, endpoint=endpoint)

      base_url = endpoint or self.endpoint
      datahub_url = f'{base_url}/api/v1/datasets/{hub_id}/repo/tree'
      query = dict(
          Revision=revision,
          Root=root_path,
          Recursive='True' if recursive else 'False',
          PageNumber=page_number,
          PageSize=page_size,
      )

      response = self.session.get(datahub_url, params=query, cookies=ModelScopeConfig.get_cookies())
      payload = response.json()
      datahub_raise_on_error(datahub_url, payload, response)
      return payload['Data']['Files']
  def get_dataset(
          self,
          dataset_id: str,
          revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
          endpoint: Optional[str] = None
  ):
      """
      Fetch the information of a dataset.

      Args:
          dataset_id (str): The dataset id.
          revision (Optional[str]): The revision of the dataset; when falsy no
              `Revision` query parameter is sent.
          endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.

      Returns:
          The `API_RESPONSE_FIELD_DATA` entry of the JSON response.
      """
      cookies = ModelScopeConfig.get_cookies()
      base_url = endpoint or self.endpoint

      path = f'{base_url}/api/v1/datasets/{dataset_id}'
      if revision:
          path = f'{path}?Revision={revision}'

      response = self.session.get(
          path, cookies=cookies, headers=self.builder_headers(self.headers))
      raise_for_http_status(response)
      payload = response.json()
      datahub_raise_on_error(path, payload, response)
      return payload[API_RESPONSE_FIELD_DATA]
  def get_dataset_meta_file_list(self, dataset_name: str, namespace: str,
                                 dataset_id: str, revision: str, endpoint: Optional[str] = None):
      """
      Get the meta file-list of the dataset.

      Args:
          dataset_name (str): The dataset name (only used in the error message).
          namespace (str): The namespace (only used in the error message).
          dataset_id (str): The dataset id used in the repo-tree URL.
          revision (str): The revision to list.
          endpoint (Optional[str]): The endpoint to use; falls back to `self.endpoint` when falsy.

      Returns:
          The `Data.Files` entry of the repo-tree response.

      Raises:
          NotExistError: If the `Data` field of the response is None.
      """
      base_url = endpoint or self.endpoint
      datahub_url = f'{base_url}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'

      response = self.session.get(
          datahub_url,
          cookies=ModelScopeConfig.get_cookies(),
          headers=self.builder_headers(self.headers),
      )
      payload = response.json()
      datahub_raise_on_error(datahub_url, payload, response)

      data = payload['Data']
      if data is None:
          raise NotExistError(
              f'The modelscope dataset [dataset_name = {dataset_name}, namespace = {namespace}, '
              f'version = {revision}] dose not exist')
      return data['Files']
  @staticmethod
  def dump_datatype_file(dataset_type: int, meta_cache_dir: str):
      """
      Write a marker file whose name encodes the dataset type, so that the
      dataset formation can be determined locally without calling the datahub.

      The file is named `<dataset_type><DatasetFormations.formation_mark_ext.value>`
      and is created inside `meta_cache_dir`.
      More details, please refer to the class
      `modelscope.utils.constant.DatasetFormations`.
      """
      marker_name = f'{str(dataset_type)}{DatasetFormations.formation_mark_ext.value}'
      marker_path = os.path.join(meta_cache_dir, marker_name)
      with open(marker_path, 'w') as marker:
          marker.write('*** Automatically-generated file, do not modify ***')
  def get_dataset_meta_files_local_paths(self, dataset_name: str,
                                         namespace: str,
                                         revision: str,
                                         meta_cache_dir: str, dataset_type: int, file_list: list,
                                         endpoint: Optional[str] = None):
      """
      Download the meta files of a dataset into `meta_cache_dir`.

      Only entries of `file_list` whose extension is listed in
      `DatasetMetaFormats[DatasetFormations(dataset_type)]` are handled. A file that
      already exists locally is reused without contacting the hub.

      Args:
          dataset_name (str): The dataset name.
          namespace (str): The namespace (owner) of the dataset.
          revision (str): The revision to download from.
          meta_cache_dir (str): Local directory receiving the meta files.
          dataset_type (int): The dataset type, convertible to `DatasetFormations`.
          file_list (list): Items providing a `'Path'` key.
          endpoint (Optional[str]): The endpoint to use; falls back to `self.endpoint` when falsy.

      Returns:
          tuple: `(local_paths, dataset_formation)` where `local_paths` maps each
              file extension to the list of local file paths.
      """
      local_paths = defaultdict(list)
      dataset_formation = DatasetFormations(dataset_type)
      dataset_meta_format = DatasetMetaFormats[dataset_formation]
      cookies = ModelScopeConfig.get_cookies()

      # Dump the data_type as a local file
      HubApi.dump_datatype_file(dataset_type=dataset_type, meta_cache_dir=meta_cache_dir)

      if not endpoint:
          endpoint = self.endpoint

      for file_info in file_list:
          file_path = file_info['Path']
          extension = os.path.splitext(file_path)[-1]
          if extension not in dataset_meta_format:
              continue

          local_path = os.path.join(meta_cache_dir, file_path)
          # Check the cache before issuing the request: a reused file needs no download.
          if os.path.exists(local_path):
              logger.warning(
                  f"Reusing dataset {dataset_name}'s python file ({local_path})"
              )
              local_paths[extension].append(local_path)
              continue

          datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
                        f'Revision={revision}&FilePath={file_path}'
          r = self.session.get(datahub_url, cookies=cookies)
          raise_for_http_status(r)

          # `file_path` may contain sub-directories that do not exist locally yet.
          parent_dir = os.path.dirname(local_path)
          if parent_dir:
              os.makedirs(parent_dir, exist_ok=True)
          with open(local_path, 'wb') as f:
              f.write(r.content)
          local_paths[extension].append(local_path)

      return local_paths, dataset_formation
  @staticmethod
  def fetch_meta_files_from_url(url, out_path, chunk_size=1024, mode=DownloadMode.REUSE_DATASET_IF_EXISTS):
      """
      Fetch the meta-data files from the url, e.g. csv/jsonl files.

      The result is stored in the directory `out_path` under a file named after the
      MD5 hex digest of `url`. When `url` ends with 'jsonl' the records are converted
      to CSV; any other content is written line by line.

      Args:
          url: The url of the meta-data file.
          out_path: The directory that holds the cached file.
          chunk_size: Number of lines processed per chunk. Defaults to 1024.
          mode: With `DownloadMode.FORCE_REDOWNLOAD` an existing cached file is removed
              and fetched again; otherwise an existing cached file is reused.

      Returns:
          The path of the cached meta-data file.
      """
      import hashlib
      from tqdm.auto import tqdm
      import pandas as pd

      out_path = os.path.join(out_path, hashlib.md5(url.encode(encoding='UTF-8')).hexdigest())
      if mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists(out_path):
          os.remove(out_path)
      if os.path.exists(out_path):
          logger.info(f'Reusing cached meta-data file: {out_path}')
          return out_path

      cookies = ModelScopeConfig.get_cookies()

      def get_chunk(resp):
          """Yield `(lines, n_bytes)`: up to `chunk_size` decoded lines and their raw byte count."""
          lines, n_bytes = [], 0
          for raw in resp.iter_lines():
              n_bytes += len(raw) + 1  # +1 for the line terminator removed by iter_lines
              lines.append(raw.decode('utf-8'))
              if len(lines) >= chunk_size:
                  yield lines, n_bytes
                  lines, n_bytes = [], 0
          yield lines, n_bytes

      is_jsonl = url.endswith('jsonl')
      # Download into a temporary file and move it into place only on success, so a
      # failed or interrupted transfer is never mistaken for a valid cached file.
      tmp_path = f'{out_path}.incomplete'

      # Make the request and get the response content as TextIO
      logger.info('Loading meta-data file ...')
      response = requests.get(url, cookies=cookies, stream=True)
      try:
          # Refuse to cache the body of an HTTP error response.
          raise_for_http_status(response)
          total_size = int(response.headers.get('content-length', 0))
          # The total comes from content-length (bytes), so progress is reported in bytes too.
          # NOTE(review): for a compressed transfer the decoded byte count may exceed the total.
          progress = tqdm(total=total_size or None, unit='B', unit_scale=True, dynamic_ncols=True)
          try:
              header_written = False
              # Lines are decoded as UTF-8, so write them back with the same encoding
              # instead of the platform default.
              with open(tmp_path, 'w', encoding='utf-8') as f:
                  for lines, n_bytes in get_chunk(response):
                      progress.update(n_bytes)
                      if is_jsonl:
                          records = [json.loads(line) for line in lines if line.strip()]
                          if not records:
                              continue
                          # Only the first non-empty chunk carries the CSV header.
                          pd.DataFrame(records).to_csv(
                              f, index=False, header=not header_written, escapechar='\\')
                          header_written = True
                      else:
                          # csv or others
                          f.writelines(f'{line}\n' for line in lines)
          finally:
              progress.close()
          os.replace(tmp_path, out_path)
      except BaseException:
          if os.path.exists(tmp_path):
              os.remove(tmp_path)
          raise
      finally:
          response.close()

      return out_path
  def get_dataset_file_url(
          self,
          file_name: str,
          dataset_name: str,
          namespace: str,
          revision: Optional[str] = DEFAULT_DATASET_REVISION,
          view: Optional[bool] = False,
          extension_filter: Optional[bool] = True,
          endpoint: Optional[str] = None):
      """
      Build the repo URL of a file in a dataset.

      Args:
          file_name (str): The file path inside the dataset, sent as `FilePath`.
          dataset_name (str): The dataset name.
          namespace (str): The namespace (owner) of the dataset.
          revision (Optional[str]): Sent as `Revision`. Defaults to `DEFAULT_DATASET_REVISION`.
          view (Optional[bool]): Sent as `View`. Defaults to False.
          extension_filter (Optional[bool]): Not used by the current implementation.
          endpoint (Optional[str]): The endpoint to use; falls back to `self.endpoint` when falsy.

      Returns:
          str: The file URL with the url-encoded query string.

      Raises:
          ValueError: If `file_name`, `dataset_name` or `namespace` is empty.
      """
      if not (file_name and dataset_name and namespace):
          raise ValueError('Args (file_name, dataset_name, namespace) cannot be empty!')

      # NOTE(review): an earlier comment asked for FilePath to be the last URL
      # parameter, yet View is emitted after it -- confirm which one is intended.
      query = urlencode({'Source': 'SDK', 'Revision': revision, 'FilePath': file_name, 'View': view})

      base_url = endpoint or self.endpoint
      return f'{base_url}/api/v1/datasets/{namespace}/{dataset_name}/repo?{query}'
  def get_dataset_file_url_origin(
          self,
          file_name: str,
          dataset_name: str,
          namespace: str,
          revision: Optional[str] = DEFAULT_DATASET_REVISION,
          endpoint: Optional[str] = None):
      """
      Turn a meta file name into its repo URL.

      Returns:
          The repo URL when `file_name` is non-empty and its extension is listed in
          `META_FILES_FORMAT`; otherwise `file_name` unchanged.
      """
      base_url = endpoint or self.endpoint
      if not file_name:
          return file_name
      if os.path.splitext(file_name)[-1] not in META_FILES_FORMAT:
          return file_name
      return (f'{base_url}/api/v1/datasets/{namespace}/{dataset_name}/repo?'
              f'Revision={revision}&FilePath={file_name}')
  def get_dataset_access_config(
          self,
          dataset_name: str,
          namespace: str,
          revision: Optional[str] = DEFAULT_DATASET_REVISION,
          endpoint: Optional[str] = None):
      """
      Query the `ststoken` API of a dataset via `datahub_remote_call`.

      Returns:
          Whatever `datahub_remote_call` returns for the `ststoken` URL.
      """
      base_url = endpoint or self.endpoint
      dataset_url = f'{base_url}/api/v1/datasets/{namespace}/{dataset_name}/'
      return self.datahub_remote_call(f'{dataset_url}ststoken?Revision={revision}')
  def get_dataset_access_config_session(
          self,
          dataset_name: str,
          namespace: str,
          check_cookie: bool,
          revision: Optional[str] = DEFAULT_DATASET_REVISION,
          endpoint: Optional[str] = None):
      """
      Query the `ststoken` API of a dataset through `self.session`.

      Args:
          dataset_name (str): The dataset name.
          namespace (str): The namespace (owner) of the dataset.
          check_cookie (bool): When True the cookies come from
              `self._check_cookie(use_cookies=True)`, otherwise from
              `ModelScopeConfig.get_cookies()`.
          revision (Optional[str]): Sent as `Revision`. Defaults to `DEFAULT_DATASET_REVISION`.
          endpoint (Optional[str]): The endpoint to use; falls back to `self.endpoint` when falsy.

      Returns:
          The `Data` field of the JSON response.
      """
      base_url = endpoint or self.endpoint
      datahub_url = (f'{base_url}/api/v1/datasets/{namespace}/{dataset_name}/'
                     f'ststoken?Revision={revision}')

      cookies = self._check_cookie(use_cookies=True) if check_cookie else ModelScopeConfig.get_cookies()

      response = self.session.get(
          url=datahub_url,
          cookies=cookies,
          headers=self.builder_headers(self.headers))
      payload = response.json()
      raise_on_error(payload)
      return payload['Data']
  def get_virgo_meta(self, dataset_id: str, version: int = 1) -> dict:
      """
      Get virgo dataset meta info.

      Args:
          dataset_id (str): Sent as `dataSetId` in the request body.
          version (int): Sent as `dataSetVersion` in the request body. Defaults to 1.

      Returns:
          dict: The `data` field of the JSON response.

      Raises:
          RuntimeError: If the environment variable named by
              `VirgoDatasetConfig.env_virgo_endpoint` is unset or empty, or if the
              `code` field of the response is not 0.
      """
      virgo_endpoint = os.environ.get(VirgoDatasetConfig.env_virgo_endpoint, '')
      if not virgo_endpoint:
          raise RuntimeError(f'Virgo endpoint is not set in env: {VirgoDatasetConfig.env_virgo_endpoint}')

      virgo_dataset_url = f'{virgo_endpoint}/data/set/download'

      # `get_cookies()` can return None (other methods of this class check for it);
      # converting None would raise a TypeError, so only convert a real cookie jar.
      cookie_jar = ModelScopeConfig.get_cookies()
      cookies = requests.utils.dict_from_cookiejar(cookie_jar) if cookie_jar is not None else None

      dataset_info = dict(
          dataSetId=dataset_id,
          dataSetVersion=version
      )
      data = dict(
          data=dataset_info,
      )
      r = self.session.post(url=virgo_dataset_url,
                            json=data,
                            cookies=cookies,
                            headers=self.builder_headers(self.headers),
                            timeout=900)
      resp = r.json()
      if resp['code'] != 0:
          raise RuntimeError(f'Failed to get virgo dataset: {resp}')

      return resp['data']
  def get_dataset_access_config_for_unzipped(self,
                                             dataset_name: str,
                                             namespace: str,
                                             revision: str,
                                             zip_file_name: str,
                                             endpoint: Optional[str] = None):
      """
      Get the `ststoken` data of a dataset, extended with the directory of an unzipped archive.

      Two requests are made: the dataset info (to read its `Visibility`) and the
      `ststoken` API. The returned data gets a `Dir` entry of the form
      `<visibility>-unzipped/<namespace>_<dataset_name>_<zip_file_name>`, where
      `<visibility>` is looked up in `VisibilityMap`.

      Returns:
          The `Data` field of the `ststoken` response, including the added `Dir` entry.
      """
      base_url = endpoint or self.endpoint
      datahub_url = f'{base_url}/api/v1/datasets/{namespace}/{dataset_name}'
      cookies = ModelScopeConfig.get_cookies()

      # First request: the dataset info, which carries the visibility of the dataset.
      info_response = self.session.get(
          url=datahub_url, cookies=cookies, headers=self.builder_headers(self.headers))
      info_payload = info_response.json()
      raise_on_error(info_payload)
      visibility = VisibilityMap.get(info_payload['Data']['Visibility'])

      # Second request: the sts token of the requested revision.
      sts_response = self.session.get(
          url=f'{datahub_url}/ststoken?Revision={revision}',
          cookies=cookies,
          headers=self.builder_headers(self.headers))
      sts_payload = sts_response.json()
      raise_on_error(sts_payload)
      data_sts = sts_payload['Data']

      archive_key = namespace + '_' + dataset_name + '_' + zip_file_name
      data_sts['Dir'] = visibility + '-unzipped' + '/' + archive_key
      return data_sts
  def list_oss_dataset_objects(self, dataset_name, namespace, max_limit,
                               is_recursive, is_filter_dir, revision, endpoint: Optional[str] = None):
      """
      Call the `oss/tree` API of a dataset.

      The arguments `max_limit`, `revision`, `is_recursive` and `is_filter_dir` are
      sent as the query parameters `MaxLimit`, `Revision`, `Recursive` and `FilterDir`.

      Returns:
          The `Data` field of the JSON response.
      """
      base_url = endpoint or self.endpoint
      query = f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}'
      url = f'{base_url}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?{query}'

      response = self.session.get(url=url, cookies=ModelScopeConfig.get_cookies(), timeout=1800)
      payload = response.json()
      raise_on_error(payload)
      return payload['Data']
  def delete_oss_dataset_object(self, object_name: str, dataset_name: str,
                                namespace: str, revision: str, endpoint: Optional[str] = None) -> str:
      """
      Send a DELETE request to the `oss` API of a dataset for one object.

      Args:
          object_name (str): Sent as the `Path` query parameter.
          dataset_name (str): The dataset name.
          namespace (str): The namespace (owner) of the dataset.
          revision (str): Sent as the `Revision` query parameter.
          endpoint (Optional[str]): The endpoint to use; falls back to `self.endpoint` when falsy.

      Returns:
          str: The `Message` field of the JSON response.

      Raises:
          ValueError: If any of the four required arguments is empty.
      """
      if not (object_name and dataset_name and namespace and revision):
          raise ValueError('Args cannot be empty!')

      base_url = endpoint or self.endpoint
      url = f'{base_url}/api/v1/datasets/{namespace}/{dataset_name}/oss?Path={object_name}&Revision={revision}'

      response = self.session.delete(url=url, cookies=ModelScopeConfig.get_cookies())
      payload = response.json()
      raise_on_error(payload)
      return payload['Message']
  def delete_oss_dataset_dir(self, object_name: str, dataset_name: str,
                             namespace: str, revision: str, endpoint: Optional[str] = None) -> str:
      """
      Send a DELETE request to the `oss/prefix` API of a dataset.

      Args:
          object_name (str): The prefix; it is sent as `Prefix` with a trailing '/' appended.
          dataset_name (str): The dataset name.
          namespace (str): The namespace (owner) of the dataset.
          revision (str): Sent as the `Revision` query parameter.
          endpoint (Optional[str]): The endpoint to use; falls back to `self.endpoint` when falsy.

      Returns:
          str: The `Message` field of the JSON response.

      Raises:
          ValueError: If any of the four required arguments is empty.
      """
      if not (object_name and dataset_name and namespace and revision):
          raise ValueError('Args cannot be empty!')

      base_url = endpoint or self.endpoint
      url = (f'{base_url}/api/v1/datasets/{namespace}/{dataset_name}/oss/prefix?Prefix={object_name}/'
             f'&Revision={revision}')

      response = self.session.delete(url=url, cookies=ModelScopeConfig.get_cookies())
      payload = response.json()
      raise_on_error(payload)
      return payload['Message']
  def datahub_remote_call(self, url):
      """
      Issue a GET request to `url` and return the `Data` field of its JSON response.

      The request carries the locally stored cookies and the user agent from
      `ModelScopeConfig.get_user_agent()`; errors are reported through
      `datahub_raise_on_error`.
      """
      response = self.session.get(
          url,
          cookies=ModelScopeConfig.get_cookies(),
          headers={'user-agent': ModelScopeConfig.get_user_agent()})
      payload = response.json()
      datahub_raise_on_error(url, payload, response)
      return payload['Data']
  def dataset_download_statistics(self, dataset_name: str, namespace: str,
                                  use_streaming: bool = False, endpoint: Optional[str] = None) -> None:
      """
      Report a dataset download to the hub (download count and download uv).

      Nothing is sent when `dataset_name` or `namespace` is empty, when the
      environment variable `CI_TEST` equals 'True', or when `use_streaming` is set.
      Any exception raised while reporting is logged and not propagated.
      """
      is_ci_test = os.getenv('CI_TEST') == 'True'
      base_url = endpoint or self.endpoint
      if not (dataset_name and namespace) or is_ci_test or use_streaming:
          return

      try:
          cookies = ModelScopeConfig.get_cookies()
          dataset_url = f'{base_url}/api/v1/datasets/{namespace}/{dataset_name}'

          # Download count
          count_response = self.session.post(
              f'{dataset_url}/download/increase',
              cookies=cookies,
              headers=self.builder_headers(self.headers))
          raise_for_http_status(count_response)

          # Download uv: channel and user name come from the environment when present.
          channel = os.environ.get(MODELSCOPE_CLOUD_ENVIRONMENT, DownloadChannel.LOCAL.value)
          user_name = os.environ.get(MODELSCOPE_CLOUD_USERNAME, '')
          uv_response = self.session.post(
              f'{dataset_url}/download/uv/{channel}?user={user_name}',
              cookies=cookies,
              headers=self.builder_headers(self.headers))
          raise_on_error(uv_response.json())
      except Exception as e:
          logger.error(e)
  def builder_headers(self, headers):
      """
      Return a new header dict: a fresh uuid4 hex string stored under
      `MODELSCOPE_REQUEST_ID`, followed by every entry of `headers`
      (entries of `headers` win on a key clash).
      """
      request_headers = {MODELSCOPE_REQUEST_ID: str(uuid.uuid4().hex)}
      request_headers.update(headers)
      return request_headers
  def get_file_base_path(self, repo_id: str, endpoint: Optional[str] = None) -> str:
      """
      Build the base URL of the dataset `repo` API, ending with '?' so that query
      parameters can be appended by the caller.

      Args:
          repo_id (str): The repo id; it must split on '/' into exactly two parts.
          endpoint (Optional[str]): The endpoint to use; falls back to `self.endpoint` when falsy.

      Returns:
          str: `{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?`
      """
      owner, name = repo_id.split('/')
      base_url = endpoint if endpoint else self.endpoint
      return '/'.join([f'{base_url}', 'api', 'v1', 'datasets', f'{owner}', f'{name}', 'repo?'])
  def create_repo(
          self,
          repo_id: str,
          *,
          token: Union[str, bool, None] = None,
          visibility: Optional[str] = Visibility.PUBLIC,
          repo_type: Optional[str] = REPO_TYPE_MODEL,
          chinese_name: Optional[str] = None,
          license: Optional[str] = Licenses.APACHE_V2,
          endpoint: Optional[str] = None,
          exist_ok: Optional[bool] = False,
          create_default_config: Optional[bool] = True,
          aigc_model: Optional[AigcModel] = None,
          **kwargs,
  ) -> str:
      """
      Create a repository on the ModelScope Hub.

      Args:
          repo_id (str): The repo id in the format of `owner_name/repo_name`.
          token (Union[str, bool, None]): The access token.
          visibility (Optional[str]): The visibility of the repo,
              could be `public`, `private`, `internal`, default to `public`.
          repo_type (Optional[str]): The repo type, default to `model`.
          chinese_name (Optional[str]): The Chinese name of the repo.
          license (Optional[str]): The license of the repo, default to `apache-2.0`.
          endpoint (Optional[str]): The endpoint to use.
              In the format of `https://www.modelscope.cn` or 'https://www.modelscope.ai'
          exist_ok (Optional[bool]): If the repo exists, whether to return the repo url directly.
          create_default_config (Optional[bool]): If True, create a default configuration file in the model repo.
          aigc_model (Optional[AigcModel]): Passed through to `create_model` for model repos.
          **kwargs: The additional arguments. `config_json` (dict) is merged over the
              default configuration when `create_default_config` is set.

      Returns:
          str: The repo url.

      Raises:
          ValueError: If `repo_id` is empty or malformed, the repo already exists and
              `exist_ok` is False, `visibility` is not supported, or `repo_type` is invalid.
      """
      if not repo_id:
          raise ValueError('Repo id cannot be empty!')

      if not endpoint:
          endpoint = self.endpoint

      self.login(access_token=token, endpoint=endpoint)

      repo_exists: bool = self.repo_exists(repo_id, repo_type=repo_type, endpoint=endpoint, token=token)
      if repo_exists:
          if exist_ok:
              repo_url: str = f'{endpoint}/{repo_type}s/{repo_id}'
              logger.warning(f'Repo {repo_id} already exists, got repo url: {repo_url}')
              return repo_url
          raise ValueError(f'Repo {repo_id} already exists!')

      repo_id_list = repo_id.split('/')
      if len(repo_id_list) != 2:
          raise ValueError('Invalid repo id, should be in the format of `owner_name/repo_name`')
      namespace, repo_name = repo_id_list

      def _visibility_code(visibility_cls) -> int:
          """Translate the `visibility` name into the code declared on `visibility_cls`."""
          codes = {k: v for k, v in visibility_cls.__dict__.items() if not k.startswith('__')}
          name = visibility.upper() if visibility is not None else None
          code = codes.get(name)
          if code is None:
              # Report the value the caller passed in, not the failed lookup result.
              raise ValueError(f'Invalid visibility: {visibility}, '
                               f'supported visibilities: `public`, `private`, `internal`')
          return code

      if repo_type == REPO_TYPE_MODEL:
          repo_url: str = self.create_model(
              model_id=repo_id,
              visibility=_visibility_code(ModelVisibility),
              license=license,
              chinese_name=chinese_name,
              aigc_model=aigc_model
          )
          if create_default_config:
              with tempfile.TemporaryDirectory() as temp_cache_dir:
                  from modelscope.hub.repository import Repository
                  repo = Repository(temp_cache_dir, repo_id)
                  default_config = {
                      'framework': 'pytorch',
                      'task': 'text-generation',
                      'allow_remote': True
                  }
                  config_json = kwargs.get('config_json')
                  if not config_json:
                      config_json = {}
                  # Entries of the user supplied config override the defaults.
                  config = {**default_config, **config_json}
                  add_content_to_file(
                      repo,
                      'configuration.json', [json.dumps(config)],
                      ignore_push_error=True)
          print(f'New model created successfully at {repo_url}.', flush=True)

      elif repo_type == REPO_TYPE_DATASET:
          repo_url: str = self.create_dataset(
              dataset_name=repo_name,
              namespace=namespace,
              chinese_name=chinese_name,
              license=license,
              visibility=_visibility_code(DatasetVisibility),
          )
          print(f'New dataset created successfully at {repo_url}.', flush=True)

      else:
          raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')

      return repo_url
  def create_commit(
          self,
          repo_id: str,
          operations: Iterable[CommitOperation],
          *,
          commit_message: str,
          commit_description: Optional[str] = None,
          token: str = None,
          repo_type: Optional[str] = REPO_TYPE_MODEL,
          revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
          endpoint: Optional[str] = None,
          max_retries: int = 3,
          timeout: int = 180,
  ) -> CommitInfo:
      """
      Create a commit on the ModelScope Hub with retry mechanism.

      Network-level failures (`requests` exceptions) and 5xx responses are retried
      after a one second pause. Any other non-200 response is reported immediately,
      and so is a failure to process a 200 response: the commit may already have
      been created, so the request is not sent again.

      Args:
          repo_id (str): The repo id in the format of `owner_name/repo_name`.
          operations (Iterable[CommitOperation]): The commit operations.
          commit_message (str): The commit message. Defaults to `Commit to {repo_id}` when empty.
          commit_description (Optional[str]): The commit description.
          token (str): The access token. If None, will use the cookies from the local cache.
              See `https://modelscope.cn/my/myaccesstoken` to get your token.
          repo_type (Optional[str]): The repo type, should be `model` or `dataset`. Defaults to `model`.
          revision (Optional[str]): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`.
          endpoint (Optional[str]): The endpoint to use.
              In the format of `https://www.modelscope.cn` or 'https://www.modelscope.ai'
          max_retries (int): Number of max retry attempts (default: 3). At least one
              attempt is always made.
          timeout (int): Timeout for each request in seconds (default: 180).

      Returns:
          CommitInfo: The commit info.

      Raises:
          ValueError: If `repo_id` is empty, `repo_type` is unsupported, or the server
              answers with a non-200 status outside the 5xx range.
          requests.exceptions.RequestException: If all retry attempts fail.
      """
      if not repo_id:
          raise ValueError('Repo id cannot be empty!')

      if not endpoint:
          endpoint = self.endpoint

      if repo_type not in REPO_TYPE_SUPPORT:
          raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')

      url = f'{endpoint}/api/v1/repos/{repo_type}s/{repo_id}/commit/{revision}'
      commit_message = commit_message or f'Commit to {repo_id}'
      commit_description = commit_description or ''

      cookies = self.get_cookies(access_token=token, cookies_required=True)

      # Construct payload; it is serialized once because it does not change between attempts.
      payload = self._prepare_commit_payload(
          operations=operations,
          commit_message=commit_message,
      )
      body = json.dumps(payload)

      # POST with retry mechanism
      attempts = max(1, max_retries)
      last_error = None
      for attempt in range(attempts):
          if attempt > 0:
              logger.info(f'Attempt {attempt + 1} to create commit for {repo_id}...')

          try:
              response = requests.post(
                  url,
                  headers=self.builder_headers(self.headers),
                  data=body,
                  cookies=cookies,
                  timeout=timeout,
              )
          except requests.exceptions.RequestException as e:
              last_error = e
              logger.warning(f'Request failed on attempt {attempt + 1}: {str(e)}')
          else:
              if response.status_code == 200:
                  resp = response.json()
                  # `Data` may be missing or null; fall back to an empty oid in both cases.
                  oid = (resp.get('Data') or {}).get('oid', '')
                  logger.info(f'Commit succeeded: {url}')
                  return CommitInfo(
                      commit_url=url,
                      commit_message=commit_message,
                      commit_description=commit_description,
                      oid=oid,
                  )

              try:
                  error_detail = response.json()
              except ValueError:
                  # JSON decoding errors are ValueError subclasses; fall back to the raw body.
                  error_detail = response.text
              error_msg = (
                  f'HTTP {response.status_code} error from {url}: '
                  f'{error_detail}'
              )

              # If server error (5xx), we can retry, otherwise (4xx) raise immediately
              if not 500 <= response.status_code < 600:
                  raise ValueError(f'Client request failed: {error_msg}')
              last_error = error_msg
              logger.warning(f'Server error on attempt {attempt + 1}: {error_msg}')

          if attempt < attempts - 1:
              time.sleep(1)

      # All retries exhausted
      raise requests.exceptions.RequestException(
          f'Failed to create commit after {attempts} attempts. Last error: {last_error}'
      )
  1869. def upload_file(
  1870. self,
  1871. *,
  1872. path_or_fileobj: Union[str, Path, bytes, BinaryIO],
  1873. path_in_repo: str,
  1874. repo_id: str,
  1875. token: Union[str, None] = None,
  1876. repo_type: Optional[str] = REPO_TYPE_MODEL,
  1877. commit_message: Optional[str] = None,
  1878. commit_description: Optional[str] = None,
  1879. buffer_size_mb: Optional[int] = 1,
  1880. tqdm_desc: Optional[str] = '[Uploading]',
  1881. disable_tqdm: Optional[bool] = False,
  1882. revision: Optional[str] = DEFAULT_REPOSITORY_REVISION
  1883. ) -> CommitInfo:
  1884. """
  1885. Upload a file to the ModelScope Hub.
  1886. Args:
  1887. path_or_fileobj (Union[str, Path, bytes, BinaryIO]):
  1888. The local file path or file-like object (BinaryIO) or bytes to upload.
  1889. path_in_repo (str): The path in the repo to upload to.
  1890. repo_id (str): The repo id in the format of `owner_name/repo_name`.
  1891. token (Union[str, None]): The access token. If None, will use the cookies from the local cache.
  1892. See `https://modelscope.cn/my/myaccesstoken` to get your token.
  1893. repo_type (Optional[str]): The repo type, default to `model`.
  1894. commit_message (Optional[str]): The commit message.
  1895. commit_description (Optional[str]): The commit description.
  1896. buffer_size_mb (Optional[int]): The buffer size in MB for reading the file. Default to 1MB.
  1897. tqdm_desc (Optional[str]): The description for the tqdm progress bar. Default to '[Uploading]'.
  1898. disable_tqdm (Optional[bool]): Whether to disable the tqdm progress bar. Default to False.
  1899. revision (Optional[str]): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`.
  1900. Returns:
  1901. CommitInfo: The commit info.
  1902. Examples:
  1903. >>> from modelscope.hub.api import HubApi
  1904. >>> api = HubApi()
  1905. >>> commit_info = api.upload_file(
  1906. ... path_or_fileobj='/path/to/your/file.txt',
  1907. ... path_in_repo='optional/path/in/repo/file.txt',
  1908. ... repo_id='your-namespace/your-repo-name',
  1909. ... commit_message='Upload file.txt to ModelScope hub'
  1910. ... )
  1911. >>> print(commit_info)
  1912. """
  1913. if repo_type not in REPO_TYPE_SUPPORT:
  1914. raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
  1915. if not path_or_fileobj:
  1916. raise ValueError('Path or file object cannot be empty!')
  1917. # Check authentication first
  1918. self.get_cookies(access_token=token, cookies_required=True)
  1919. if isinstance(path_or_fileobj, (str, Path)):
  1920. path_or_fileobj = os.path.abspath(os.path.expanduser(path_or_fileobj))
  1921. path_in_repo = path_in_repo or os.path.basename(path_or_fileobj)
  1922. else:
  1923. # If path_or_fileobj is bytes or BinaryIO, then path_in_repo must be provided
  1924. if not path_in_repo:
  1925. raise ValueError('Arg `path_in_repo` cannot be empty!')
  1926. # Read file content if path_or_fileobj is a file-like object (BinaryIO)
  1927. # TODO: to be refined
  1928. if isinstance(path_or_fileobj, io.BufferedIOBase):
  1929. path_or_fileobj = path_or_fileobj.read()
  1930. self.upload_checker.check_file(path_or_fileobj)
  1931. self.upload_checker.check_normal_files(
  1932. file_path_list=[path_or_fileobj],
  1933. repo_type=repo_type,
  1934. )
  1935. commit_message = (
  1936. commit_message if commit_message is not None else f'Upload {path_in_repo} to ModelScope hub'
  1937. )
  1938. if buffer_size_mb <= 0:
  1939. raise ValueError('Buffer size: `buffer_size_mb` must be greater than 0')
  1940. hash_info_d: dict = get_file_hash(
  1941. file_path_or_obj=path_or_fileobj,
  1942. buffer_size_mb=buffer_size_mb,
  1943. )
  1944. file_size: int = hash_info_d['file_size']
  1945. file_hash: str = hash_info_d['file_hash']
  1946. self.create_repo(repo_id=repo_id,
  1947. token=token,
  1948. repo_type=repo_type,
  1949. endpoint=self.endpoint,
  1950. exist_ok=True,
  1951. create_default_config=False)
  1952. upload_res: dict = self._upload_blob(
  1953. repo_id=repo_id,
  1954. repo_type=repo_type,
  1955. sha256=file_hash,
  1956. size=file_size,
  1957. data=path_or_fileobj,
  1958. disable_tqdm=disable_tqdm,
  1959. tqdm_desc=tqdm_desc,
  1960. )
  1961. # Construct commit info and create commit
  1962. add_operation: CommitOperationAdd = CommitOperationAdd(
  1963. path_in_repo=path_in_repo,
  1964. path_or_fileobj=path_or_fileobj,
  1965. file_hash_info=hash_info_d,
  1966. )
  1967. add_operation._upload_mode = 'lfs' if self.upload_checker.is_lfs(path_or_fileobj, repo_type) else 'normal'
  1968. add_operation._is_uploaded = upload_res['is_uploaded']
  1969. operations = [add_operation]
  1970. print(f'Committing file to {repo_id} ...', flush=True)
  1971. commit_info: CommitInfo = self.create_commit(
  1972. repo_id=repo_id,
  1973. operations=operations,
  1974. commit_message=commit_message,
  1975. commit_description=commit_description,
  1976. token=token,
  1977. repo_type=repo_type,
  1978. revision=revision,
  1979. )
  1980. return commit_info
  1981. def upload_folder(
  1982. self,
  1983. *,
  1984. repo_id: str,
  1985. folder_path: Union[str, Path, List[str], List[Path]],
  1986. path_in_repo: Optional[str] = '',
  1987. commit_message: Optional[str] = None,
  1988. commit_description: Optional[str] = None,
  1989. token: Union[str, None] = None,
  1990. repo_type: Optional[str] = REPO_TYPE_MODEL,
  1991. allow_patterns: Optional[Union[List[str], str]] = None,
  1992. ignore_patterns: Optional[Union[List[str], str]] = None,
  1993. max_workers: int = DEFAULT_MAX_WORKERS,
  1994. revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
  1995. ) -> Union[CommitInfo, List[CommitInfo]]:
  1996. """
  1997. Upload a folder to the ModelScope Hub.
  1998. Args:
  1999. repo_id (str): The repo id in the format of `owner_name/repo_name`.
  2000. folder_path (Union[str, Path, List[str], List[Path]]): The folder path or list of file paths to upload.
  2001. path_in_repo (Optional[str]): The path in the repo to upload to.
  2002. commit_message (Optional[str]): The commit message.
  2003. commit_description (Optional[str]): The commit description.
  2004. token (Union[str, None]): The access token. If None, will use the cookies from the local cache.
  2005. See `https://modelscope.cn/my/myaccesstoken` to get your token.
  2006. repo_type (Optional[str]): The repo type, default to `model`.
  2007. allow_patterns (Optional[Union[List[str], str]]): The patterns to allow.
  2008. ignore_patterns (Optional[Union[List[str], str]]): The patterns to ignore.
  2009. max_workers (int): The maximum number of workers to use for uploading files concurrently.
  2010. Defaults to `DEFAULT_MAX_WORKERS`.
  2011. revision (Optional[str]): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`.
  2012. Returns:
  2013. Union[CommitInfo, List[CommitInfo]]:
  2014. The commit info or list of commit infos if multiple batches are committed.
  2015. Examples:
  2016. >>> from modelscope.hub.api import HubApi
  2017. >>> api = HubApi()
  2018. >>> commit_info = api.upload_folder(
  2019. ... repo_id='your-namespace/your-repo-name',
  2020. ... folder_path='/path/to/your/folder',
  2021. ... path_in_repo='optional/path/in/repo',
  2022. ... commit_message='Upload my folder',
  2023. ... token='your-access-token'
  2024. ... )
  2025. >>> print(commit_info.commit_url)
  2026. """
  2027. if not repo_id:
  2028. raise ValueError('The arg `repo_id` cannot be empty!')
  2029. if folder_path is None:
  2030. raise ValueError('The arg `folder_path` cannot be None!')
  2031. if repo_type not in REPO_TYPE_SUPPORT:
  2032. raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
  2033. # Check authentication first
  2034. self.get_cookies(access_token=token, cookies_required=True)
  2035. allow_patterns = allow_patterns if allow_patterns else None
  2036. ignore_patterns = ignore_patterns if ignore_patterns else None
  2037. # Ignore .git .cache folders
  2038. if ignore_patterns is None:
  2039. ignore_patterns = []
  2040. elif isinstance(ignore_patterns, str):
  2041. ignore_patterns = [ignore_patterns]
  2042. ignore_patterns += DEFAULT_IGNORE_PATTERNS
  2043. # Cover the ignore patterns if both allow and ignore patterns are provided
  2044. if allow_patterns is not None:
  2045. if '**' in allow_patterns:
  2046. ignore_patterns = []
  2047. ignore_patterns = [
  2048. p for p in ignore_patterns if p not in allow_patterns
  2049. ]
  2050. commit_message = (
  2051. commit_message if commit_message is not None else f'Upload to {repo_id} on ModelScope hub'
  2052. )
  2053. commit_description = commit_description or 'Uploading files'
  2054. # Get the list of files to upload, e.g. [('data/abc.png', '/path/to/abc.png'), ...]
  2055. logger.info('Preparing files to upload ...')
  2056. prepared_repo_objects = self._prepare_upload_folder(
  2057. folder_path_or_files=folder_path,
  2058. path_in_repo=path_in_repo,
  2059. allow_patterns=allow_patterns,
  2060. ignore_patterns=ignore_patterns,
  2061. )
  2062. if len(prepared_repo_objects) == 0:
  2063. raise ValueError(f'No files to upload in the folder: {folder_path} !')
  2064. logger.info(f'Checking {len(prepared_repo_objects)} files to upload ...')
  2065. self.upload_checker.check_normal_files(
  2066. file_path_list=[item for _, item in prepared_repo_objects],
  2067. repo_type=repo_type,
  2068. )
  2069. self.create_repo(repo_id=repo_id,
  2070. token=token,
  2071. repo_type=repo_type,
  2072. endpoint=self.endpoint,
  2073. exist_ok=True,
  2074. create_default_config=False)
  2075. @thread_executor(max_workers=max_workers, disable_tqdm=False)
  2076. def _upload_items(item_pair, **kwargs):
  2077. file_path_in_repo, file_path = item_pair
  2078. hash_info_d: dict = get_file_hash(
  2079. file_path_or_obj=file_path,
  2080. )
  2081. file_size: int = hash_info_d['file_size']
  2082. file_hash: str = hash_info_d['file_hash']
  2083. upload_res: dict = self._upload_blob(
  2084. repo_id=repo_id,
  2085. repo_type=repo_type,
  2086. sha256=file_hash,
  2087. size=file_size,
  2088. data=file_path,
  2089. disable_tqdm=file_size <= UPLOAD_BLOB_TQDM_DISABLE_THRESHOLD,
  2090. tqdm_desc='[Uploading ' + file_path_in_repo + ']',
  2091. )
  2092. return {
  2093. 'file_path_in_repo': file_path_in_repo,
  2094. 'file_path': file_path,
  2095. 'is_uploaded': upload_res['is_uploaded'],
  2096. 'file_hash_info': hash_info_d,
  2097. }
  2098. uploaded_items_list = _upload_items(
  2099. prepared_repo_objects,
  2100. repo_id=repo_id,
  2101. token=token,
  2102. repo_type=repo_type,
  2103. commit_message=commit_message,
  2104. commit_description=commit_description,
  2105. buffer_size_mb=1,
  2106. disable_tqdm=False,
  2107. )
  2108. # Construct commit info and create commit
  2109. operations = []
  2110. for item_d in uploaded_items_list:
  2111. prepared_path_in_repo: str = item_d['file_path_in_repo']
  2112. prepared_file_path: str = item_d['file_path']
  2113. is_uploaded: bool = item_d['is_uploaded']
  2114. file_hash_info: dict = item_d['file_hash_info']
  2115. opt = CommitOperationAdd(
  2116. path_in_repo=prepared_path_in_repo,
  2117. path_or_fileobj=prepared_file_path,
  2118. file_hash_info=file_hash_info,
  2119. )
  2120. # check normal or lfs
  2121. opt._upload_mode = 'lfs' if self.upload_checker.is_lfs(prepared_file_path, repo_type) else 'normal'
  2122. opt._is_uploaded = is_uploaded
  2123. operations.append(opt)
  2124. if len(operations) == 0:
  2125. raise ValueError(f'No files to upload in the folder: {folder_path} !')
  2126. # Commit the operations in batches
  2127. commit_batch_size: int = UPLOAD_COMMIT_BATCH_SIZE if UPLOAD_COMMIT_BATCH_SIZE > 0 else len(operations)
  2128. num_batches = (len(operations) - 1) // commit_batch_size + 1
  2129. print(f'Committing {len(operations)} files in {num_batches} batch(es) of size {commit_batch_size}.',
  2130. flush=True)
  2131. commit_infos: List[CommitInfo] = []
  2132. for i in tqdm(range(num_batches), desc='[Committing batches] ', total=num_batches):
  2133. batch_operations = operations[i * commit_batch_size: (i + 1) * commit_batch_size]
  2134. batch_commit_message = f'{commit_message} (batch {i + 1}/{num_batches})'
  2135. commit_info: CommitInfo = self.create_commit(
  2136. repo_id=repo_id,
  2137. operations=batch_operations,
  2138. commit_message=batch_commit_message,
  2139. commit_description=commit_description,
  2140. token=token,
  2141. repo_type=repo_type,
  2142. revision=revision,
  2143. )
  2144. commit_infos.append(commit_info)
  2145. return commit_infos[0] if len(commit_infos) == 1 else commit_infos
  2146. def _upload_blob(
  2147. self,
  2148. *,
  2149. repo_id: str,
  2150. repo_type: str,
  2151. sha256: str,
  2152. size: int,
  2153. data: Union[str, Path, bytes, BinaryIO],
  2154. disable_tqdm: Optional[bool] = False,
  2155. tqdm_desc: Optional[str] = '[Uploading]',
  2156. buffer_size_mb: Optional[int] = 1,
  2157. ) -> dict:
  2158. res_d: dict = dict(
  2159. url=None,
  2160. is_uploaded=False,
  2161. status_code=None,
  2162. status_msg=None,
  2163. )
  2164. objects = [{'oid': sha256, 'size': size}]
  2165. upload_objects = self._validate_blob(
  2166. repo_id=repo_id,
  2167. repo_type=repo_type,
  2168. objects=objects,
  2169. )
  2170. # upload_object: {'url': 'xxx', 'oid': 'xxx'}
  2171. upload_object = upload_objects[0] if len(upload_objects) == 1 else None
  2172. if upload_object is None:
  2173. logger.debug(f'Blob {sha256[:8]} has already uploaded, reuse it.')
  2174. res_d['is_uploaded'] = True
  2175. return res_d
  2176. cookies = ModelScopeConfig.get_cookies()
  2177. cookies = dict(cookies) if cookies else None
  2178. if cookies is None:
  2179. raise ValueError('Token does not exist, please login first.')
  2180. self.headers.update({'Cookie': f"m_session_id={cookies['m_session_id']}"})
  2181. headers = self.builder_headers(self.headers)
  2182. def read_in_chunks(file_object, pbar, chunk_size=buffer_size_mb * 1024 * 1024):
  2183. """Lazy function (generator) to read a file piece by piece."""
  2184. while True:
  2185. ck = file_object.read(chunk_size)
  2186. if not ck:
  2187. break
  2188. pbar.update(len(ck))
  2189. yield ck
  2190. with tqdm(
  2191. total=size,
  2192. unit='B',
  2193. unit_scale=True,
  2194. desc=tqdm_desc,
  2195. disable=disable_tqdm
  2196. ) as pbar:
  2197. if isinstance(data, (str, Path)):
  2198. with open(data, 'rb') as f:
  2199. response = requests.put(
  2200. upload_object['url'],
  2201. headers=headers,
  2202. data=read_in_chunks(f, pbar)
  2203. )
  2204. elif isinstance(data, bytes):
  2205. response = requests.put(
  2206. upload_object['url'],
  2207. headers=headers,
  2208. data=read_in_chunks(io.BytesIO(data), pbar)
  2209. )
  2210. elif isinstance(data, io.BufferedIOBase):
  2211. response = requests.put(
  2212. upload_object['url'],
  2213. headers=headers,
  2214. data=read_in_chunks(data, pbar)
  2215. )
  2216. else:
  2217. raise ValueError('Invalid data type to upload')
  2218. raise_for_http_status(rsp=response)
  2219. resp = response.json()
  2220. raise_on_error(rsp=resp)
  2221. res_d['url'] = upload_object['url']
  2222. res_d['status_code'] = resp['Code']
  2223. res_d['status_msg'] = resp['Message']
  2224. return res_d
  2225. def _validate_blob(
  2226. self,
  2227. *,
  2228. repo_id: str,
  2229. repo_type: str,
  2230. objects: List[Dict[str, Any]],
  2231. endpoint: Optional[str] = None
  2232. ) -> List[Dict[str, Any]]:
  2233. """
  2234. Check the blob has already uploaded.
  2235. True -- uploaded; False -- not uploaded.
  2236. Args:
  2237. repo_id (str): The repo id ModelScope.
  2238. repo_type (str): The repo type. `dataset`, `model`, etc.
  2239. objects (List[Dict[str, Any]]): The objects to check.
  2240. oid (str): The sha256 hash value.
  2241. size (int): The size of the blob.
  2242. endpoint: the endpoint to use, default to None to use endpoint specified in the class
  2243. Returns:
  2244. List[Dict[str, Any]]: The result of the check.
  2245. """
  2246. # construct URL
  2247. if not endpoint:
  2248. endpoint = self.endpoint
  2249. url = f'{endpoint}/api/v1/repos/{repo_type}s/{repo_id}/info/lfs/objects/batch'
  2250. # build payload
  2251. payload = {
  2252. 'operation': 'upload',
  2253. 'objects': objects,
  2254. }
  2255. cookies = ModelScopeConfig.get_cookies()
  2256. if cookies is None:
  2257. raise ValueError('Token does not exist, please login first.')
  2258. response = requests.post(
  2259. url,
  2260. headers=self.builder_headers(self.headers),
  2261. data=json.dumps(payload),
  2262. cookies=cookies
  2263. )
  2264. raise_for_http_status(rsp=response)
  2265. resp = response.json()
  2266. raise_on_error(rsp=resp)
  2267. upload_objects = [] # list of objects to upload, [{'url': 'xxx', 'oid': 'xxx'}, ...]
  2268. resp_objects = resp['Data']['objects']
  2269. for obj in resp_objects:
  2270. upload_objects.append(
  2271. {'url': obj['actions']['upload']['href'],
  2272. 'oid': obj['oid']}
  2273. )
  2274. return upload_objects
  2275. def _prepare_upload_folder(
  2276. self,
  2277. folder_path_or_files: Union[str, Path, List[str], List[Path]],
  2278. path_in_repo: str,
  2279. allow_patterns: Optional[Union[List[str], str]] = None,
  2280. ignore_patterns: Optional[Union[List[str], str]] = None,
  2281. ) -> List[Union[tuple, list]]:
  2282. folder_path = None
  2283. files_path = None
  2284. if isinstance(folder_path_or_files, list):
  2285. if os.path.isfile(folder_path_or_files[0]):
  2286. files_path = folder_path_or_files
  2287. else:
  2288. raise ValueError('Uploading multiple folders is not supported now.')
  2289. else:
  2290. if os.path.isfile(folder_path_or_files):
  2291. files_path = [folder_path_or_files]
  2292. else:
  2293. folder_path = folder_path_or_files
  2294. if files_path is None:
  2295. self.upload_checker.check_folder(folder_path)
  2296. folder_path = Path(folder_path).expanduser().resolve()
  2297. if not folder_path.is_dir():
  2298. raise ValueError(f"Provided path: '{folder_path}' is not a directory")
  2299. # List files from folder
  2300. relpath_to_abspath = {
  2301. path.relative_to(folder_path).as_posix(): path
  2302. for path in sorted(folder_path.glob('**/*')) # sorted to be deterministic
  2303. if path.is_file()
  2304. }
  2305. else:
  2306. relpath_to_abspath = {}
  2307. for path in files_path:
  2308. if os.path.isfile(path):
  2309. self.upload_checker.check_file(path)
  2310. relpath_to_abspath[os.path.basename(path)] = path
  2311. # Filter files
  2312. filtered_repo_objects = list(
  2313. RepoUtils.filter_repo_objects(
  2314. relpath_to_abspath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
  2315. )
  2316. )
  2317. prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else ''
  2318. prepared_repo_objects = [
  2319. (prefix + relpath, str(relpath_to_abspath[relpath]))
  2320. for relpath in filtered_repo_objects
  2321. ]
  2322. logger.info(f'Prepared {len(prepared_repo_objects)} files for upload.')
  2323. return prepared_repo_objects
  2324. @staticmethod
  2325. def _prepare_commit_payload(
  2326. operations: Iterable[CommitOperation],
  2327. commit_message: str,
  2328. ) -> Dict[str, Any]:
  2329. """
  2330. Prepare the commit payload to be sent to the ModelScope hub.
  2331. """
  2332. payload = {
  2333. 'commit_message': commit_message,
  2334. 'actions': []
  2335. }
  2336. nb_ignored_files = 0
  2337. # 2. Send operations, one per line
  2338. for operation in operations:
  2339. # Skip ignored files
  2340. if isinstance(operation, CommitOperationAdd) and operation._should_ignore:
  2341. logger.debug(f"Skipping file '{operation.path_in_repo}' in commit (ignored by gitignore file).")
  2342. nb_ignored_files += 1
  2343. continue
  2344. # 2.a. Case adding a normal file
  2345. if isinstance(operation, CommitOperationAdd) and operation._upload_mode == 'normal':
  2346. commit_action = {
  2347. 'action': 'update' if operation._is_uploaded else 'create',
  2348. 'path': operation.path_in_repo,
  2349. 'type': 'normal',
  2350. 'size': operation.upload_info.size,
  2351. 'sha256': '',
  2352. 'content': operation.b64content().decode(),
  2353. 'encoding': 'base64',
  2354. }
  2355. payload['actions'].append(commit_action)
  2356. # 2.b. Case adding an LFS file
  2357. elif isinstance(operation, CommitOperationAdd) and operation._upload_mode == 'lfs':
  2358. commit_action = {
  2359. 'action': 'update' if operation._is_uploaded else 'create',
  2360. 'path': operation.path_in_repo,
  2361. 'type': 'lfs',
  2362. 'size': operation.upload_info.size,
  2363. 'sha256': operation.upload_info.sha256,
  2364. 'content': '',
  2365. 'encoding': '',
  2366. }
  2367. payload['actions'].append(commit_action)
  2368. else:
  2369. raise ValueError(
  2370. f'Unknown operation to commit. Operation: {operation}. Upload mode:'
  2371. f" {getattr(operation, '_upload_mode', None)}"
  2372. )
  2373. if nb_ignored_files > 0:
  2374. logger.info(f'Skipped {nb_ignored_files} file(s) in commit (ignored by gitignore file).')
  2375. return payload
  2376. def _get_internal_acceleration_domain(self, internal_timeout: float = 0.2):
  2377. """
  2378. Get the internal acceleration domain.
  2379. Args:
  2380. internal_timeout (float): The timeout for the request. Default to 0.2s
  2381. Returns:
  2382. str: The internal acceleration domain. e.g. `cn-hangzhou`, `cn-zhangjiakou`
  2383. """
  2384. def send_request(url: str, timeout: float):
  2385. try:
  2386. response = requests.get(url, timeout=timeout)
  2387. response.raise_for_status()
  2388. except requests.exceptions.RequestException:
  2389. response = None
  2390. return response
  2391. internal_url = f'{self.endpoint}/api/v1/repos/internalAccelerationInfo'
  2392. # Get internal url and region for acceleration
  2393. internal_info_response = send_request(url=internal_url, timeout=internal_timeout)
  2394. region_id: str = ''
  2395. if internal_info_response is not None:
  2396. internal_info_response = internal_info_response.json()
  2397. if 'Data' in internal_info_response:
  2398. query_addr = internal_info_response['Data']['InternalRegionQueryAddress']
  2399. else:
  2400. query_addr: str = ''
  2401. if query_addr:
  2402. domain_response = send_request(query_addr, timeout=internal_timeout)
  2403. if domain_response is not None:
  2404. region_id = domain_response.text.strip()
  2405. return region_id
  2406. def delete_files(self,
  2407. repo_id: str,
  2408. repo_type: str,
  2409. delete_patterns: Union[str, List[str]],
  2410. *,
  2411. revision: Optional[str] = DEFAULT_MODEL_REVISION,
  2412. endpoint: Optional[str] = None) -> Dict[str, Any]:
  2413. """
  2414. Delete files in batch using glob (wildcard) patterns, e.g. '*.py', 'data/*.csv', 'foo*', etc.
  2415. Example:
  2416. # Delete all Python and Markdown files in a model repo
  2417. api.delete_files(
  2418. repo_id='your_username/your_model',
  2419. repo_type=REPO_TYPE_MODEL,
  2420. delete_patterns=['*.py', '*.md']
  2421. )
  2422. # Delete all CSV files in the data/ directory of a dataset repo
  2423. api.delete_files(
  2424. repo_id='your_username/your_dataset',
  2425. repo_type=REPO_TYPE_DATASET,
  2426. delete_patterns='data/*.csv'
  2427. )
  2428. Args:
  2429. repo_id (str): 'owner/repo_name' or 'owner/dataset_name', e.g. 'Koko/my_model'
  2430. repo_type (str): REPO_TYPE_MODEL or REPO_TYPE_DATASET
  2431. delete_patterns (str or List[str]): List of glob patterns, e.g. '*.py', 'data/*.csv', 'foo*'
  2432. revision (str, optional): Branch or tag name
  2433. endpoint (str, optional): API endpoint
  2434. Returns:
  2435. dict: Deletion result
  2436. """
  2437. if repo_type not in REPO_TYPE_SUPPORT:
  2438. raise ValueError(f'Unsupported repo_type: {repo_type}')
  2439. if not delete_patterns:
  2440. raise ValueError('delete_patterns cannot be empty')
  2441. if isinstance(delete_patterns, str):
  2442. delete_patterns = [delete_patterns]
  2443. cookies = ModelScopeConfig.get_cookies()
  2444. if not endpoint:
  2445. endpoint = self.endpoint
  2446. if cookies is None:
  2447. raise ValueError('Token does not exist, please login first.')
  2448. headers = self.builder_headers(self.headers)
  2449. # List all files in the repo
  2450. if repo_type == REPO_TYPE_MODEL:
  2451. files = self.get_model_files(
  2452. repo_id,
  2453. revision=revision or DEFAULT_MODEL_REVISION,
  2454. recursive=True,
  2455. endpoint=endpoint,
  2456. use_cookies=cookies,
  2457. )
  2458. file_paths = [f['Path'] for f in files]
  2459. elif repo_type == REPO_TYPE_DATASET:
  2460. file_paths = []
  2461. page_number = 1
  2462. page_size = 100
  2463. while True:
  2464. try:
  2465. dataset_files: List[Dict[str, Any]] = self.get_dataset_files(
  2466. repo_id=repo_id,
  2467. revision=revision or DEFAULT_DATASET_REVISION,
  2468. recursive=True,
  2469. page_number=page_number,
  2470. page_size=page_size,
  2471. endpoint=endpoint,
  2472. )
  2473. except Exception as e:
  2474. logger.error(f'Get dataset: {repo_id} file list failed, message: {str(e)}')
  2475. break
  2476. # Parse data (Type: 'tree' or 'blob')
  2477. for file_info_d in dataset_files:
  2478. if file_info_d['Type'] != 'tree':
  2479. file_paths.append(file_info_d['Path'])
  2480. if len(dataset_files) < page_size:
  2481. break
  2482. page_number += 1
  2483. else:
  2484. raise ValueError(f'Unsupported repo_type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
  2485. # Glob pattern matching
  2486. to_delete = []
  2487. for path in file_paths:
  2488. for delete_pattern in delete_patterns:
  2489. if fnmatch.fnmatch(path, delete_pattern):
  2490. to_delete.append(path)
  2491. break
  2492. deleted_files, failed_files = [], []
  2493. for path in to_delete:
  2494. try:
  2495. if repo_type == REPO_TYPE_MODEL:
  2496. owner, repo_name = repo_id.split('/')
  2497. url = f'{endpoint}/api/v1/models/{owner}/{repo_name}/file'
  2498. params = {
  2499. 'Revision': revision or DEFAULT_MODEL_REVISION,
  2500. 'FilePath': path
  2501. }
  2502. elif repo_type == REPO_TYPE_DATASET:
  2503. owner, dataset_name = repo_id.split('/')
  2504. url = f'{endpoint}/api/v1/datasets/{owner}/{dataset_name}/repo'
  2505. params = {
  2506. 'FilePath': path
  2507. }
  2508. else:
  2509. raise ValueError(f'Unsupported repo_type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
  2510. r = self.session.delete(url, params=params, cookies=cookies, headers=headers)
  2511. raise_for_http_status(r)
  2512. resp = r.json()
  2513. raise_on_error(resp)
  2514. deleted_files.append(path)
  2515. except Exception as e:
  2516. failed_files.append(path)
  2517. logger.error(f'Failed to delete {path}: {str(e)}')
  2518. return {
  2519. 'deleted_files': deleted_files,
  2520. 'failed_files': failed_files,
  2521. 'total_files': len(to_delete)
  2522. }
  2523. def set_repo_visibility(self,
  2524. repo_id: str,
  2525. repo_type: Literal['model', 'dataset'],
  2526. visibility: Literal['private', 'public'],
  2527. token: Union[str, None] = None
  2528. ) -> dict:
  2529. """
  2530. Set the visibility of a repo.
  2531. Args:
  2532. repo_id (str): The repo id in the format of `owner_name/repo_name`.
  2533. repo_type (Literal['model', 'dataset']): The repo type, `model` or `dataset`.
  2534. visibility (Literal['private', 'public']): The visibility to set, `private` or `public`.
  2535. token (Union[str, None]): The access token. If None, will use the cookies from the local cache.
  2536. See `https://modelscope.cn/my/myaccesstoken` to get your token.
  2537. Returns:
  2538. dict: The response from the server.
  2539. """
  2540. if not repo_id:
  2541. raise ValueError('The arg `repo_id` cannot be empty!')
  2542. if visibility not in ['private', 'public']:
  2543. raise ValueError(f'Invalid visibility: {visibility}, supported visibilities: `private`, `public`')
  2544. visibility_map: Dict[str, int] = {v: k for k, v in VisibilityMap.items()}
  2545. visibility_code: int = visibility_map.get(visibility, 5)
  2546. cookies = self.get_cookies(access_token=token, cookies_required=True)
  2547. if repo_type == REPO_TYPE_MODEL:
  2548. model_info = self.get_model(model_id=repo_id)
  2549. path = f'{self.endpoint}/api/v1/models/{repo_id}'
  2550. tasks = model_info.get('Tasks')
  2551. model_tasks = ''
  2552. if isinstance(tasks, list) and tasks:
  2553. first = tasks[0]
  2554. if isinstance(first, dict) and first:
  2555. model_tasks = first.get('name')
  2556. payload = {
  2557. 'ChineseName': model_info.get('ChineseName', ''),
  2558. 'ModelFramework': model_info.get('ModelFramework', 'Pytorch'),
  2559. 'Visibility': visibility_code,
  2560. 'ProtectedMode': 2,
  2561. 'ApprovalMode': model_info.get('ApprovalMode', 2),
  2562. 'Description': model_info.get('Description', ''),
  2563. 'AigcType': model_info.get('AigcType', ''),
  2564. 'VisionFoundation': model_info.get('VisionFoundation', ''),
  2565. 'ModelCover': model_info.get('ModelCover', ''),
  2566. 'SubScientificField': model_info.get('SubScientificField', None),
  2567. 'ScientificField': model_info.get('NEXA', {}).get('ScientificField', ''),
  2568. 'Source': model_info.get('NEXA', {}).get('Source', ''),
  2569. 'ModelTask': model_tasks,
  2570. 'License': model_info.get('License', ''),
  2571. }
  2572. elif repo_type == REPO_TYPE_DATASET:
  2573. repo_id_parts = repo_id.split('/')
  2574. if len(repo_id_parts) != 2 or not all(repo_id_parts):
  2575. raise ValueError(f'Invalid dataset repo_id: {repo_id}, should be in format of `owner/dataset_name`')
  2576. dataset_idx, _ = self.get_dataset_id_and_type(
  2577. dataset_name=repo_id_parts[1],
  2578. namespace=repo_id_parts[0],
  2579. )
  2580. path = f'{self.endpoint}/api/v1/datasets/{dataset_idx}'
  2581. payload = {
  2582. 'Visibility': visibility_code,
  2583. 'ProtectedMode': 2,
  2584. }
  2585. else:
  2586. raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
  2587. r = self.session.put(
  2588. path,
  2589. json=payload,
  2590. cookies=cookies,
  2591. headers=self.builder_headers(self.headers))
  2592. raise_for_http_status(r)
  2593. resp = r.json()
  2594. raise_on_error(resp)
  2595. return resp
  2596. class ModelScopeConfig:
  2597. path_credential = expanduser(MODELSCOPE_CREDENTIALS_PATH)
  2598. COOKIES_FILE_NAME = 'cookies'
  2599. GIT_TOKEN_FILE_NAME = 'git_token'
  2600. USER_INFO_FILE_NAME = 'user'
  2601. USER_SESSION_ID_FILE_NAME = 'session'
  2602. cookie_expired_warning = False
  2603. @staticmethod
  2604. def make_sure_credential_path_exist():
  2605. os.makedirs(ModelScopeConfig.path_credential, exist_ok=True)
  2606. @staticmethod
  2607. def save_cookies(cookies: CookieJar):
  2608. ModelScopeConfig.make_sure_credential_path_exist()
  2609. with open(
  2610. os.path.join(ModelScopeConfig.path_credential,
  2611. ModelScopeConfig.COOKIES_FILE_NAME), 'wb+') as f:
  2612. pickle.dump(cookies, f)
  2613. @staticmethod
  2614. def get_cookies():
  2615. cookies_path = os.path.join(ModelScopeConfig.path_credential,
  2616. ModelScopeConfig.COOKIES_FILE_NAME)
  2617. if os.path.exists(cookies_path):
  2618. with open(cookies_path, 'rb') as f:
  2619. cookies = pickle.load(f)
  2620. for cookie in cookies:
  2621. if cookie.name == 'm_session_id' and cookie.is_expired() and \
  2622. not ModelScopeConfig.cookie_expired_warning:
  2623. ModelScopeConfig.cookie_expired_warning = True
  2624. logger.info('Not logged-in, you can login for uploading'
  2625. 'or accessing controlled entities.')
  2626. return None
  2627. return cookies
  2628. return None
  2629. @staticmethod
  2630. def get_user_session_id():
  2631. session_path = os.path.join(ModelScopeConfig.path_credential,
  2632. ModelScopeConfig.USER_SESSION_ID_FILE_NAME)
  2633. session_id = ''
  2634. if os.path.exists(session_path):
  2635. with open(session_path, 'rb') as f:
  2636. session_id = str(f.readline().strip(), encoding='utf-8')
  2637. return session_id
  2638. if session_id == '' or len(session_id) != 32:
  2639. session_id = str(uuid.uuid4().hex)
  2640. ModelScopeConfig.make_sure_credential_path_exist()
  2641. with open(session_path, 'w+') as wf:
  2642. wf.write(session_id)
  2643. return session_id
  2644. @staticmethod
  2645. def save_token(token: str):
  2646. ModelScopeConfig.make_sure_credential_path_exist()
  2647. with open(
  2648. os.path.join(ModelScopeConfig.path_credential,
  2649. ModelScopeConfig.GIT_TOKEN_FILE_NAME), 'w+') as f:
  2650. f.write(token)
  2651. @staticmethod
  2652. def save_user_info(user_name: str, user_email: str):
  2653. ModelScopeConfig.make_sure_credential_path_exist()
  2654. with open(
  2655. os.path.join(ModelScopeConfig.path_credential,
  2656. ModelScopeConfig.USER_INFO_FILE_NAME), 'w+') as f:
  2657. f.write('%s:%s' % (user_name, user_email))
  2658. @staticmethod
  2659. def get_user_info() -> Tuple[str, str]:
  2660. try:
  2661. with open(
  2662. os.path.join(ModelScopeConfig.path_credential,
  2663. ModelScopeConfig.USER_INFO_FILE_NAME),
  2664. 'r',
  2665. encoding='utf-8') as f:
  2666. info = f.read()
  2667. return info.split(':')[0], info.split(':')[1]
  2668. except FileNotFoundError:
  2669. pass
  2670. return None, None
  2671. @staticmethod
  2672. def get_token() -> Optional[str]:
  2673. """
  2674. Get token or None if not existent.
  2675. Returns:
  2676. `str` or `None`: The token, `None` if it doesn't exist.
  2677. """
  2678. token = None
  2679. try:
  2680. with open(
  2681. os.path.join(ModelScopeConfig.path_credential,
  2682. ModelScopeConfig.GIT_TOKEN_FILE_NAME),
  2683. 'r',
  2684. encoding='utf-8') as f:
  2685. token = f.read()
  2686. except FileNotFoundError:
  2687. pass
  2688. return token
  @staticmethod
  def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
      """Build the user-agent string describing this client for a request.

      Args:
          user_agent (`str`, `dict`, *optional*):
              Extra user-agent info. A dict is rendered as `key/value` pairs
              joined by `; `; a string is appended as-is.

      Returns:
          The formatted user-agent string.
      """
      # Telemetry fields taken from the environment when the cloud-specific
      # variables are set; otherwise the defaults below are reported.
      env = os.environ.get(MODELSCOPE_CLOUD_ENVIRONMENT, 'custom')
      user_name = os.environ.get(MODELSCOPE_CLOUD_USERNAME, 'unknown')

      from modelscope import __version__

      fields = (
          __version__,
          platform.python_version(),
          ModelScopeConfig.get_user_session_id(),
          platform.platform(),
          platform.processor(),
          env,
          user_name,
      )
      ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % fields

      if isinstance(user_agent, dict):
          extra = '; '.join(f'{k}/{v}' for k, v in user_agent.items())
      elif isinstance(user_agent, str):
          extra = user_agent
      else:
          # Nothing (or an unsupported type) supplied: base string only.
          return ua
      return ua + '; ' + extra
  class UploadingCheck:
      """
      Check the files and folders to be uploaded.

      Args:
          max_file_count (int): The maximum number of files to be uploaded. Default to `UPLOAD_MAX_FILE_COUNT`.
          max_file_count_in_dir (int): The maximum number of files in a directory.
              Default to `UPLOAD_MAX_FILE_COUNT_IN_DIR`.
          max_file_size (int): The maximum size of a single file in bytes. Default to `UPLOAD_MAX_FILE_SIZE`.
          size_threshold_to_enforce_lfs (int): The size threshold to enforce LFS in bytes.
              Files larger than this size will be enforced to be uploaded via LFS.
              Default to `UPLOAD_SIZE_THRESHOLD_TO_ENFORCE_LFS`.
          normal_file_size_total_limit (int): The total size limit of normal files in bytes.
              Default to `UPLOAD_NORMAL_FILE_SIZE_TOTAL_LIMIT`.

      Examples:
          >>> from modelscope.hub.api import UploadingCheck
          >>> upload_checker = UploadingCheck()
          >>> upload_checker.check_file('/path/to/your/file.txt')
          >>> upload_checker.check_folder('/path/to/your/folder')
          >>> is_lfs = upload_checker.is_lfs('/path/to/your/file.txt', repo_type='model')
          >>> print(f'Is LFS: {is_lfs}')
      """

      def __init__(
          self,
          max_file_count: int = UPLOAD_MAX_FILE_COUNT,
          max_file_count_in_dir: int = UPLOAD_MAX_FILE_COUNT_IN_DIR,
          max_file_size: int = UPLOAD_MAX_FILE_SIZE,
          size_threshold_to_enforce_lfs: int = UPLOAD_SIZE_THRESHOLD_TO_ENFORCE_LFS,
          normal_file_size_total_limit: int = UPLOAD_NORMAL_FILE_SIZE_TOTAL_LIMIT,
      ):
          self.max_file_count = max_file_count
          self.max_file_count_in_dir = max_file_count_in_dir
          self.max_file_size = max_file_size
          self.size_threshold_to_enforce_lfs = size_threshold_to_enforce_lfs
          self.normal_file_size_total_limit = normal_file_size_total_limit

      def check_file(self, file_path_or_obj: Union[str, Path, bytes, BinaryIO]) -> None:
          """
          Check a single file to be uploaded.

          A file larger than `max_file_size` is only reported with a warning;
          it does not raise.

          Args:
              file_path_or_obj (Union[str, Path, bytes, BinaryIO]): The file path or file-like object to be checked.

          Raises:
              ValueError: If a path is given and the file does not exist.
          """
          if isinstance(file_path_or_obj, (str, Path)):
              if not os.path.exists(file_path_or_obj):
                  raise ValueError(f'File {file_path_or_obj} does not exist')

          file_size: int = get_file_size(file_path_or_obj)
          if file_size > self.max_file_size:
              logger.warning(f'File exceeds size limit: {self.max_file_size / (1024 ** 3)} GB, '
                             f'got {round(file_size / (1024 ** 3), 4)} GB')

      def check_folder(self, folder_path: Union[str, Path]) -> Tuple[int, int]:
          """
          Check a folder to be uploaded, descending into sub-directories.

          Oversized files are only reported with a warning. The existence of
          `folder_path` is not validated here; a bad path surfaces as whatever
          error `Path.iterdir()` raises.

          Args:
              folder_path (Union[str, Path]): The folder path to be checked.

          Returns:
              Tuple[int, int]: `(file_count, dir_count)` accumulated over the
              folder and all of its sub-directories.

          Raises:
              ValueError: If a sub-directory holds more than `max_file_count_in_dir`
                  items, or the file count exceeds `max_file_count`.
          """
          file_count = 0
          dir_count = 0

          folder_path = Path(folder_path)

          for item in folder_path.iterdir():
              if item.is_file():
                  file_count += 1
                  item_size: int = get_file_size(item)
                  if item_size > self.max_file_size:
                      # Build ONE message string: a second positional argument would be
                      # treated by logging as a %-format argument and break the record.
                      logger.warning(f'File {item} exceeds size limit: {self.max_file_size / (1024 ** 3)} GB, '
                                     f'got {round(item_size / (1024 ** 3), 4)} GB')
              elif item.is_dir():
                  dir_count += 1
                  # Count items in subdirectories recursively
                  sub_file_count, sub_dir_count = self.check_folder(item)
                  if (sub_file_count + sub_dir_count) > self.max_file_count_in_dir:
                      raise ValueError(f'Directory {item} contains {sub_file_count + sub_dir_count} items '
                                       f'and exceeds limit: {self.max_file_count_in_dir}')
                  file_count += sub_file_count
                  dir_count += sub_dir_count

          if file_count > self.max_file_count:
              raise ValueError(f'Total file count {file_count} and exceeds limit: {self.max_file_count}')

          return file_count, dir_count

      def is_lfs(self, file_path_or_obj: Union[str, Path, bytes, BinaryIO], repo_type: str) -> bool:
          """
          Check if a file should be uploaded via LFS.

          A file is LFS when its size exceeds `size_threshold_to_enforce_lfs`
          or its suffix is in the LFS suffix list of the repo type. Inputs that
          are not paths skip the suffix and repo-type checks and are always
          reported as LFS.

          Args:
              file_path_or_obj (Union[str, Path, bytes, BinaryIO]): The file path or file-like object to be checked.
              repo_type (str): The repo type, either `model` or `dataset`.

          Returns:
              bool: True if the file should be uploaded via LFS, False otherwise.

          Raises:
              ValueError: If a path is given and the file does not exist, or
                  `repo_type` is not supported.
          """
          # Stays True for non-path inputs, which have no suffix to inspect.
          hit_lfs_suffix = True

          if isinstance(file_path_or_obj, (str, Path)):
              file_path_or_obj = Path(file_path_or_obj)
              if not file_path_or_obj.exists():
                  raise ValueError(f'File {file_path_or_obj} does not exist')

              if repo_type == REPO_TYPE_MODEL:
                  if file_path_or_obj.suffix not in MODEL_LFS_SUFFIX:
                      hit_lfs_suffix = False
              elif repo_type == REPO_TYPE_DATASET:
                  if file_path_or_obj.suffix not in DATASET_LFS_SUFFIX:
                      hit_lfs_suffix = False
              else:
                  raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')

          file_size: int = get_file_size(file_path_or_obj)

          return file_size > self.size_threshold_to_enforce_lfs or hit_lfs_suffix

      def check_normal_files(self, file_path_list: List[Union[str, Path]], repo_type: str) -> None:
          """
          Check a list of normal (non-LFS) files to be uploaded.

          Args:
              file_path_list (List[Union[str, Path]]): The list of file paths to be checked.
              repo_type (str): The repo type, either `model` or `dataset`.

          Raises:
              ValueError: If the total size of normal files exceeds the limit.

          Returns: None
          """
          normal_file_list = [item for item in file_path_list if not self.is_lfs(item, repo_type)]
          total_size = sum(get_file_size(item) for item in normal_file_list)

          if total_size > self.normal_file_size_total_limit:
              raise ValueError(f'Total size of non-lfs files {total_size / (1024 * 1024)}MB '
                               f'and exceeds limit: {self.normal_file_size_total_limit / (1024 * 1024)}MB')