patcher.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import contextlib
  3. import importlib
  4. import inspect
  5. import os
  6. import re
  7. import sys
  8. from asyncio import Future
  9. from functools import partial
  10. from pathlib import Path
  11. from types import MethodType
  12. from typing import BinaryIO, Dict, Iterable, List, Optional, Union
  13. from modelscope.hub.constants import DEFAULT_MODELSCOPE_DATA_ENDPOINT
  14. from modelscope.utils.repo_utils import (CommitInfo, CommitOperation,
  15. CommitOperationAdd)
  16. ignore_file_pattern = [
  17. r'\w+\.bin',
  18. r'\w+\.safetensors',
  19. r'\w+\.pth',
  20. r'\w+\.pt',
  21. r'\w+\.h5',
  22. r'\w+\.ckpt',
  23. r'\w+\.zip',
  24. r'\w+\.onnx',
  25. r'\w+\.tar',
  26. r'\w+\.gz',
  27. ]
  28. def get_all_imported_modules():
  29. """Find all modules in transformers/peft/diffusers"""
  30. all_imported_modules = []
  31. transformers_include_names = [
  32. 'Auto.*',
  33. 'T5.*',
  34. 'BitsAndBytesConfig',
  35. 'GenerationConfig',
  36. 'Awq.*',
  37. 'GPTQ.*',
  38. 'BatchFeature',
  39. 'Qwen.*',
  40. 'Llama.*',
  41. 'Intern.*',
  42. 'Deepseek.*',
  43. 'PretrainedConfig',
  44. 'PreTrainedTokenizer',
  45. 'PreTrainedModel',
  46. 'PreTrainedTokenizerFast',
  47. ]
  48. peft_include_names = ['.*PeftModel.*', '.*Config']
  49. diffusers_include_names = [
  50. '^(?!TF|Flax).*Pipeline$', '^(?!TF|Flax).*Autoencoder.*',
  51. '^(?!TF|Flax).*Model$', '^(?!TF|Flax).*Adapter$', 'ImageProjection',
  52. '^(?!TF|Flax).*UNet$', '^(?!TF|Flax).*Scheduler$'
  53. ]
  54. if importlib.util.find_spec('transformers') is not None:
  55. import transformers
  56. lazy_module = sys.modules['transformers']
  57. _import_structure = lazy_module._import_structure
  58. for key in _import_structure:
  59. if 'dummy' in key.lower():
  60. continue
  61. values = _import_structure[key]
  62. for value in values:
  63. # pretrained
  64. if any([
  65. re.fullmatch(name, value)
  66. for name in transformers_include_names
  67. ]):
  68. try:
  69. module = importlib.import_module(
  70. f'.{key}', transformers.__name__)
  71. value = getattr(module, value)
  72. all_imported_modules.append(value)
  73. except: # noqa
  74. pass
  75. if importlib.util.find_spec('peft') is not None:
  76. try:
  77. import peft
  78. except: # noqa
  79. pass
  80. else:
  81. attributes = dir(peft)
  82. imports = [
  83. attr for attr in attributes if not attr.startswith('__')
  84. ]
  85. all_imported_modules.extend([
  86. getattr(peft, _import) for _import in imports if any([
  87. re.fullmatch(name, _import) for name in peft_include_names
  88. ])
  89. ])
  90. if importlib.util.find_spec('diffusers') is not None:
  91. try:
  92. import diffusers
  93. except: # noqa
  94. pass
  95. else:
  96. lazy_module = sys.modules['diffusers']
  97. if hasattr(lazy_module, '_import_structure'):
  98. _import_structure = lazy_module._import_structure
  99. for key in _import_structure:
  100. if 'dummy' in key.lower():
  101. continue
  102. values = _import_structure[key]
  103. for value in values:
  104. if any([
  105. re.fullmatch(name, value)
  106. for name in diffusers_include_names
  107. ]):
  108. try:
  109. module = importlib.import_module(
  110. f'.{key}', diffusers.__name__)
  111. value = getattr(module, value)
  112. all_imported_modules.append(value)
  113. except: # noqa
  114. pass
  115. else:
  116. attributes = dir(lazy_module)
  117. imports = [
  118. attr for attr in attributes if not attr.startswith('__')
  119. ]
  120. all_imported_modules.extend([
  121. getattr(lazy_module, _import) for _import in imports
  122. if any([
  123. re.fullmatch(name, _import)
  124. for name in diffusers_include_names
  125. ])
  126. ])
  127. return all_imported_modules
  128. def _patch_pretrained_class(all_imported_modules, wrap=False):
  129. """Patch all class to download from modelscope
  130. Args:
  131. wrap: Wrap the class or monkey patch the original class
  132. Returns:
  133. The classes after patched
  134. """
  135. def get_model_dir(pretrained_model_name_or_path,
  136. ignore_file_pattern=None,
  137. allow_file_pattern=None,
  138. **kwargs):
  139. from modelscope import snapshot_download
  140. subfolder = kwargs.pop('subfolder', None)
  141. file_filter = None
  142. if subfolder:
  143. file_filter = f'{subfolder}/*'
  144. if not os.path.exists(pretrained_model_name_or_path):
  145. revision = kwargs.pop('revision', None)
  146. if revision is None or revision == 'main':
  147. revision = 'master'
  148. if file_filter is not None:
  149. allow_file_pattern = file_filter
  150. model_dir = snapshot_download(
  151. pretrained_model_name_or_path,
  152. revision=revision,
  153. ignore_file_pattern=ignore_file_pattern,
  154. allow_file_pattern=allow_file_pattern)
  155. if subfolder:
  156. model_dir = os.path.join(model_dir, subfolder)
  157. else:
  158. model_dir = pretrained_model_name_or_path
  159. return model_dir
  160. def patch_pretrained_model_name_or_path(cls, pretrained_model_name_or_path,
  161. *model_args, **kwargs):
  162. """Patch all from_pretrained"""
  163. model_dir = get_model_dir(pretrained_model_name_or_path,
  164. kwargs.pop('ignore_file_pattern', None),
  165. kwargs.pop('allow_file_pattern', None),
  166. **kwargs)
  167. return cls._from_pretrained_origin.__func__(cls, model_dir,
  168. *model_args, **kwargs)
  169. def patch_get_config_dict(cls, pretrained_model_name_or_path, *model_args,
  170. **kwargs):
  171. """Patch all get_config_dict"""
  172. model_dir = get_model_dir(pretrained_model_name_or_path,
  173. kwargs.pop('ignore_file_pattern', None),
  174. kwargs.pop('allow_file_pattern', None),
  175. **kwargs)
  176. return cls._get_config_dict_origin.__func__(cls, model_dir,
  177. *model_args, **kwargs)
  178. def patch_peft_model_id(cls, model, model_id, *model_args, **kwargs):
  179. """Patch all peft.from_pretrained"""
  180. model_dir = get_model_dir(model_id,
  181. kwargs.pop('ignore_file_pattern', None),
  182. kwargs.pop('allow_file_pattern', None),
  183. **kwargs)
  184. return cls._from_pretrained_origin.__func__(cls, model, model_dir,
  185. *model_args, **kwargs)
  186. def patch_get_peft_type(cls, model_id, **kwargs):
  187. """Patch all _get_peft_type"""
  188. model_dir = get_model_dir(model_id,
  189. kwargs.pop('ignore_file_pattern', None),
  190. kwargs.pop('allow_file_pattern', None),
  191. **kwargs)
  192. return cls._get_peft_type_origin.__func__(cls, model_dir, **kwargs)
  193. def get_wrapped_class(
  194. module_class: 'PreTrainedModel',
  195. ignore_file_pattern: Optional[Union[str, List[str]]] = None,
  196. allow_file_pattern: Optional[Union[str, List[str]]] = None,
  197. **kwargs):
  198. """Get a custom wrapper class for auto classes to download the models from the ModelScope hub
  199. Args:
  200. module_class (`PreTrainedModel`): The actual module class
  201. ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
  202. Any file pattern to be ignored, like exact file names or file extensions.
  203. allow_file_pattern (`str` or `List`, *optional*, default to `None`):
  204. Any file pattern to be included, like exact file names or file extensions.
  205. Returns:
  206. The wrapped class
  207. """
  208. @contextlib.contextmanager
  209. def file_pattern_context(kwargs, module_class, cls):
  210. if 'allow_file_pattern' not in kwargs:
  211. kwargs['allow_file_pattern'] = allow_file_pattern
  212. if 'ignore_file_pattern' not in kwargs:
  213. kwargs['ignore_file_pattern'] = ignore_file_pattern
  214. if kwargs.get(
  215. 'allow_file_pattern') is None and module_class is not None:
  216. extra_allow_file_pattern = None
  217. if 'GenerationConfig' == module_class.__name__:
  218. from transformers.utils import GENERATION_CONFIG_NAME
  219. extra_allow_file_pattern = [
  220. GENERATION_CONFIG_NAME, r'*.py'
  221. ]
  222. elif 'Config' in module_class.__name__:
  223. from transformers import CONFIG_NAME
  224. extra_allow_file_pattern = [CONFIG_NAME, r'*.py']
  225. elif 'Tokenizer' in module_class.__name__:
  226. extra_allow_file_pattern = list(
  227. (cls.vocab_files_names.values()) if cls is not None
  228. and hasattr(cls, 'vocab_files_names') else []) + [
  229. 'chat_template.jinja', r'*.json', r'*.py',
  230. r'*.txt', r'*.model', r'*.tiktoken'
  231. ] # noqa
  232. elif 'Processor' in module_class.__name__:
  233. extra_allow_file_pattern = [
  234. 'chat_template.jinja', r'*.json', r'*.py', r'*.txt',
  235. r'*.model', r'*.tiktoken'
  236. ]
  237. kwargs['allow_file_pattern'] = extra_allow_file_pattern
  238. yield
  239. kwargs.pop('ignore_file_pattern', None)
  240. kwargs.pop('allow_file_pattern', None)
  241. def from_pretrained(model, model_id, *model_args, **kwargs):
  242. with file_pattern_context(kwargs):
  243. # model is an instance
  244. model_dir = get_model_dir(
  245. model_id,
  246. module_class=module_class,
  247. cls=module_class,
  248. **kwargs)
  249. module_obj = module_class.from_pretrained(model, model_dir,
  250. *model_args, **kwargs)
  251. return module_obj
  252. class ClassWrapper(module_class):
  253. @classmethod
  254. def from_pretrained(cls, pretrained_model_name_or_path,
  255. *model_args, **kwargs):
  256. with file_pattern_context(kwargs, module_class, cls):
  257. model_dir = get_model_dir(pretrained_model_name_or_path,
  258. **kwargs)
  259. module_obj = module_class.from_pretrained(
  260. model_dir, *model_args, **kwargs)
  261. if module_class.__name__.startswith('AutoModel'):
  262. module_obj.model_dir = model_dir
  263. return module_obj
  264. @classmethod
  265. def _get_peft_type(cls, model_id, **kwargs):
  266. with file_pattern_context(kwargs, module_class, cls):
  267. model_dir = get_model_dir(
  268. model_id,
  269. ignore_file_pattern=ignore_file_pattern,
  270. allow_file_pattern=allow_file_pattern,
  271. **kwargs)
  272. module_obj = module_class._get_peft_type(model_dir, **kwargs)
  273. return module_obj
  274. @classmethod
  275. def get_config_dict(cls, pretrained_model_name_or_path,
  276. *model_args, **kwargs):
  277. with file_pattern_context(kwargs, module_class, cls):
  278. model_dir = get_model_dir(
  279. pretrained_model_name_or_path,
  280. ignore_file_pattern=ignore_file_pattern,
  281. allow_file_pattern=allow_file_pattern,
  282. **kwargs)
  283. module_obj = module_class.get_config_dict(
  284. model_dir, *model_args, **kwargs)
  285. return module_obj
  286. def save_pretrained(
  287. self,
  288. save_directory: Union[str, os.PathLike],
  289. safe_serialization: bool = True,
  290. **kwargs,
  291. ):
  292. push_to_hub = kwargs.pop('push_to_hub', False)
  293. if push_to_hub:
  294. from modelscope.hub.push_to_hub import push_to_hub
  295. from modelscope.hub.api import HubApi
  296. from modelscope.hub.repository import Repository
  297. token = kwargs.get('token')
  298. commit_message = kwargs.pop('commit_message', None)
  299. repo_name = kwargs.pop(
  300. 'repo_id',
  301. save_directory.split(os.path.sep)[-1])
  302. api = HubApi()
  303. api.login(token)
  304. api.create_repo(repo_name)
  305. # clone the repo
  306. Repository(save_directory, repo_name)
  307. super().save_pretrained(
  308. save_directory=save_directory,
  309. safe_serialization=safe_serialization,
  310. push_to_hub=False,
  311. **kwargs)
  312. # Class members may be unpatched, so push_to_hub is done separately here
  313. if push_to_hub:
  314. push_to_hub(
  315. repo_name=repo_name,
  316. output_dir=save_directory,
  317. commit_message=commit_message,
  318. token=token)
  319. if not hasattr(module_class, 'from_pretrained'):
  320. del ClassWrapper.from_pretrained
  321. else:
  322. parameters = inspect.signature(var.from_pretrained).parameters
  323. if 'model' in parameters and 'model_id' in parameters:
  324. # peft
  325. ClassWrapper.from_pretrained = from_pretrained
  326. if not hasattr(module_class, '_get_peft_type'):
  327. del ClassWrapper._get_peft_type
  328. if not hasattr(module_class, 'get_config_dict'):
  329. del ClassWrapper.get_config_dict
  330. if not hasattr(module_class, 'save_pretrained'):
  331. del ClassWrapper.save_pretrained
  332. ClassWrapper.__name__ = module_class.__name__
  333. ClassWrapper.__qualname__ = module_class.__qualname__
  334. return ClassWrapper
  335. all_available_modules = []
  336. for var in all_imported_modules:
  337. if var is None or not hasattr(var, '__name__'):
  338. continue
  339. name = var.__name__
  340. skip_model = 'tokenizer' in name.lower() or 'config' in name.lower()
  341. if not skip_model:
  342. ignore_file_pattern_kwargs = {}
  343. else:
  344. ignore_file_pattern_kwargs = {
  345. 'ignore_file_pattern': ignore_file_pattern
  346. }
  347. try:
  348. # some TFxxx classes has import errors
  349. has_from_pretrained = hasattr(var, 'from_pretrained')
  350. has_get_peft_type = hasattr(var, '_get_peft_type')
  351. has_get_config_dict = hasattr(var, 'get_config_dict')
  352. has_save_pretrained = hasattr(var, 'save_pretrained')
  353. except: # noqa
  354. continue
  355. # save_pretrained is not a classmethod and cannot be overridden by replacing
  356. # the class method. It requires replacing the class object method.
  357. if wrap or ('pipeline' in name.lower() and has_save_pretrained):
  358. try:
  359. if (not has_from_pretrained and not has_get_config_dict
  360. and not has_get_peft_type and not has_save_pretrained):
  361. all_available_modules.append(var)
  362. else:
  363. all_available_modules.append(
  364. get_wrapped_class(var, **ignore_file_pattern_kwargs))
  365. except: # noqa
  366. all_available_modules.append(var)
  367. else:
  368. if has_from_pretrained and not hasattr(var,
  369. '_from_pretrained_origin'):
  370. parameters = inspect.signature(var.from_pretrained).parameters
  371. # different argument names
  372. is_peft = 'model' in parameters and 'model_id' in parameters
  373. var._from_pretrained_origin = var.from_pretrained
  374. if not is_peft:
  375. var.from_pretrained = classmethod(
  376. partial(patch_pretrained_model_name_or_path,
  377. **ignore_file_pattern_kwargs))
  378. else:
  379. var.from_pretrained = classmethod(
  380. partial(patch_peft_model_id,
  381. **ignore_file_pattern_kwargs))
  382. if has_get_peft_type and not hasattr(var, '_get_peft_type_origin'):
  383. var._get_peft_type_origin = var._get_peft_type
  384. var._get_peft_type = classmethod(
  385. partial(patch_get_peft_type, **ignore_file_pattern_kwargs))
  386. if has_get_config_dict and not hasattr(var,
  387. '_get_config_dict_origin'):
  388. var._get_config_dict_origin = var.get_config_dict
  389. var.get_config_dict = classmethod(
  390. partial(patch_get_config_dict,
  391. **ignore_file_pattern_kwargs))
  392. all_available_modules.append(var)
  393. def get_class_from_dynamic_module(class_reference, *args, **kwargs):
  394. from transformers.dynamic_module_utils import origin_get_class_from_dynamic_module
  395. if '--' in class_reference:
  396. repo_id, class_reference = class_reference.split('--')
  397. if not os.path.exists(repo_id):
  398. from modelscope import snapshot_download
  399. repo_id = snapshot_download(repo_id)
  400. class_reference = repo_id + '--' + class_reference
  401. return origin_get_class_from_dynamic_module(class_reference, *args,
  402. **kwargs)
  403. from transformers import dynamic_module_utils
  404. if not hasattr(dynamic_module_utils,
  405. 'origin_get_class_from_dynamic_module'):
  406. dynamic_module_utils.origin_get_class_from_dynamic_module = dynamic_module_utils.get_class_from_dynamic_module
  407. dynamic_module_utils.get_class_from_dynamic_module = get_class_from_dynamic_module
  408. from transformers.models.auto import configuration_auto
  409. configuration_auto.get_class_from_dynamic_module = get_class_from_dynamic_module
  410. return all_available_modules
  411. def _unpatch_pretrained_class(all_imported_modules):
  412. for var in all_imported_modules:
  413. if var is None:
  414. continue
  415. try:
  416. has_from_pretrained = hasattr(var, 'from_pretrained')
  417. has_get_peft_type = hasattr(var, '_get_peft_type')
  418. has_get_config_dict = hasattr(var, 'get_config_dict')
  419. except: # noqa
  420. continue
  421. if has_from_pretrained and hasattr(var, '_from_pretrained_origin'):
  422. var.from_pretrained = var._from_pretrained_origin
  423. try:
  424. delattr(var, '_from_pretrained_origin')
  425. except: # noqa
  426. pass
  427. if has_get_peft_type and hasattr(var, '_get_peft_type_origin'):
  428. var._get_peft_type = var._get_peft_type_origin
  429. try:
  430. delattr(var, '_get_peft_type_origin')
  431. except: # noqa
  432. pass
  433. if has_get_config_dict and hasattr(var, '_get_config_dict_origin'):
  434. var.get_config_dict = var._get_config_dict_origin
  435. try:
  436. delattr(var, '_get_config_dict_origin')
  437. except: # noqa
  438. pass
  439. from transformers import dynamic_module_utils
  440. if hasattr(dynamic_module_utils, 'origin_get_class_from_dynamic_module'):
  441. dynamic_module_utils.get_class_from_dynamic_module = dynamic_module_utils.origin_get_class_from_dynamic_module
  442. from transformers.models.auto import configuration_auto
  443. configuration_auto.get_class_from_dynamic_module = dynamic_module_utils.origin_get_class_from_dynamic_module
  444. delattr(dynamic_module_utils, 'origin_get_class_from_dynamic_module')
  445. def _patch_hub():
  446. import huggingface_hub
  447. from huggingface_hub import hf_api
  448. from huggingface_hub.hf_api import api
  449. from huggingface_hub.hf_api import future_compatible
  450. from modelscope import get_logger
  451. logger = get_logger()
  452. def _file_exists(
  453. self,
  454. repo_id: str,
  455. filename: str,
  456. *,
  457. repo_type: Optional[str] = None,
  458. revision: Optional[str] = None,
  459. token: Union[str, bool, None] = None,
  460. ):
  461. """Patch huggingface_hub.file_exists"""
  462. if repo_type is not None:
  463. logger.warning(
  464. 'The passed in repo_type will not be used in modelscope. Now only model repo can be queried.'
  465. )
  466. from modelscope.hub.api import HubApi
  467. api = HubApi()
  468. api.login(token)
  469. if revision is None or revision == 'main':
  470. revision = 'master'
  471. return api.file_exists(repo_id, filename, revision=revision)
  472. def _file_download(repo_id: str,
  473. filename: str,
  474. *,
  475. subfolder: Optional[str] = None,
  476. repo_type: Optional[str] = None,
  477. revision: Optional[str] = None,
  478. cache_dir: Union[str, Path, None] = None,
  479. local_dir: Union[str, Path, None] = None,
  480. token: Union[bool, str, None] = None,
  481. local_files_only: bool = False,
  482. **kwargs):
  483. """Patch huggingface_hub.hf_hub_download"""
  484. if len(kwargs) > 0:
  485. logger.warning(
  486. 'The passed in library_name,library_version,user_agent,force_download,proxies'
  487. 'etag_timeout,headers,endpoint '
  488. 'will not be used in modelscope.')
  489. assert repo_type in (
  490. None, 'model',
  491. 'dataset'), f'repo_type={repo_type} is not supported in ModelScope'
  492. if repo_type in (None, 'model'):
  493. from modelscope.hub.file_download import model_file_download as file_download
  494. else:
  495. from modelscope.hub.file_download import dataset_file_download as file_download
  496. from modelscope import HubApi
  497. api = HubApi()
  498. api.login(token)
  499. if revision is None or revision == 'main':
  500. revision = 'master'
  501. return file_download(
  502. repo_id,
  503. file_path=os.path.join(subfolder, filename)
  504. if subfolder else filename,
  505. cache_dir=cache_dir,
  506. local_dir=local_dir,
  507. local_files_only=local_files_only,
  508. revision=revision)
  509. def _whoami(self, token: Union[bool, str, None] = None) -> Dict:
  510. from modelscope.hub.api import ModelScopeConfig
  511. from modelscope.hub.api import HubApi
  512. api = HubApi()
  513. api.login(token)
  514. return {'name': ModelScopeConfig.get_user_info()[0] or 'unknown'}
  515. def create_repo(self,
  516. repo_id: str,
  517. *,
  518. token: Union[str, bool, None] = None,
  519. private: bool = False,
  520. **kwargs) -> 'RepoUrl':
  521. """
  522. Create a new repository on the hub.
  523. Args:
  524. repo_id: The ID of the repository to create.
  525. token: The authentication token to use.
  526. private: Whether the repository should be private.
  527. **kwargs: Additional arguments.
  528. Returns:
  529. RepoUrl: The URL of the created repository.
  530. """
  531. from modelscope.hub.api import HubApi
  532. api = HubApi()
  533. visibility = 'private' if private else 'public'
  534. repo_url = api.create_repo(
  535. repo_id, token=token, visibility=visibility, **kwargs)
  536. from modelscope.utils.repo_utils import RepoUrl
  537. return RepoUrl(url=repo_url, repo_type='model', repo_id=repo_id)
  538. @future_compatible
  539. def upload_folder(
  540. self,
  541. *,
  542. repo_id: str,
  543. folder_path: Union[str, Path],
  544. path_in_repo: Optional[str] = None,
  545. commit_message: Optional[str] = None,
  546. commit_description: Optional[str] = None,
  547. token: Union[str, bool, None] = None,
  548. revision: Optional[str] = 'master',
  549. ignore_patterns: Optional[Union[List[str], str]] = None,
  550. **kwargs,
  551. ):
  552. from modelscope.hub.push_to_hub import _push_files_to_hub
  553. if revision is None or revision == 'main':
  554. revision = 'master'
  555. _push_files_to_hub(
  556. path_or_fileobj=folder_path,
  557. path_in_repo=path_in_repo,
  558. repo_id=repo_id,
  559. commit_message=commit_message,
  560. commit_description=commit_description,
  561. revision=revision,
  562. token=token)
  563. from modelscope.utils.repo_utils import CommitInfo
  564. return CommitInfo(
  565. commit_url=
  566. f'{DEFAULT_MODELSCOPE_DATA_ENDPOINT}/models/{repo_id}/files',
  567. commit_message=commit_message,
  568. commit_description=commit_description,
  569. oid=None,
  570. )
  571. from modelscope.utils.constant import DEFAULT_REPOSITORY_REVISION
  572. @future_compatible
  573. def upload_file(
  574. self,
  575. *,
  576. path_or_fileobj: Union[str, Path, bytes, BinaryIO],
  577. path_in_repo: str,
  578. repo_id: str,
  579. token: Union[str, bool, None] = None,
  580. revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
  581. commit_message: Optional[str] = None,
  582. commit_description: Optional[str] = None,
  583. **kwargs,
  584. ):
  585. if revision is None or revision == 'main':
  586. revision = 'master'
  587. from modelscope.hub.push_to_hub import _push_files_to_hub
  588. _push_files_to_hub(path_or_fileobj, path_in_repo, repo_id, token,
  589. revision, commit_message, commit_description)
  590. @future_compatible
  591. def create_commit(
  592. self,
  593. repo_id: str,
  594. operations: Iterable[CommitOperation],
  595. *,
  596. commit_message: str,
  597. commit_description: Optional[str] = None,
  598. token: Union[str, bool, None] = None,
  599. repo_type: Optional[str] = None,
  600. revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
  601. **kwargs,
  602. ) -> Union[CommitInfo, Future[CommitInfo]]:
  603. from modelscope.hub.api import HubApi
  604. api = HubApi()
  605. if any(['Add' not in op.__class__.__name__ for op in operations]):
  606. raise ValueError(
  607. 'ModelScope create_commit only support Add operation for now.')
  608. if revision is None or revision == 'main':
  609. revision = 'master'
  610. all_files = [op.path_or_fileobj for op in operations]
  611. api.upload_folder(
  612. repo_id=repo_id,
  613. folder_path=all_files,
  614. commit_message=commit_message,
  615. commit_description=commit_description,
  616. token=token,
  617. revision=revision,
  618. repo_type=repo_type or 'model')
  619. def load(
  620. cls,
  621. repo_id_or_path: Union[str, Path],
  622. repo_type: Optional[str] = None,
  623. token: Optional[str] = None,
  624. ignore_metadata_errors: bool = False,
  625. ):
  626. from modelscope.hub.api import HubApi
  627. api = HubApi()
  628. api.login(token)
  629. if os.path.exists(repo_id_or_path):
  630. file_path = repo_id_or_path
  631. elif repo_type == 'model' or repo_type is None:
  632. from modelscope import model_file_download
  633. file_path = model_file_download(repo_id_or_path, 'README.md')
  634. elif repo_type == 'dataset':
  635. from modelscope import dataset_file_download
  636. file_path = dataset_file_download(repo_id_or_path, 'README.md')
  637. else:
  638. raise ValueError(
  639. f'repo_type should be `model` or `dataset`, but now is {repo_type}'
  640. )
  641. with open(file_path, 'r') as f:
  642. repo_card = cls(
  643. f.read(), ignore_metadata_errors=ignore_metadata_errors)
  644. if not hasattr(repo_card.data, 'tags'):
  645. repo_card.data.tags = []
  646. return repo_card
  647. # Patch repocard.validate
  648. from huggingface_hub import repocard
  649. if not hasattr(repocard.RepoCard, '_validate_origin'):
  650. repocard.RepoCard._validate_origin = repocard.RepoCard.validate
  651. repocard.RepoCard.validate = lambda *args, **kwargs: None
  652. repocard.RepoCard._load_origin = repocard.RepoCard.load
  653. repocard.RepoCard.load = MethodType(load, repocard.RepoCard)
  654. if not hasattr(hf_api, '_hf_hub_download_origin'):
  655. # Patch hf_hub_download
  656. hf_api._hf_hub_download_origin = huggingface_hub.file_download.hf_hub_download
  657. huggingface_hub.hf_hub_download = _file_download
  658. huggingface_hub.file_download.hf_hub_download = _file_download
  659. if not hasattr(hf_api, '_file_exists_origin'):
  660. # Patch file_exists
  661. hf_api._file_exists_origin = hf_api.file_exists
  662. hf_api.file_exists = MethodType(_file_exists, api)
  663. huggingface_hub.file_exists = hf_api.file_exists
  664. huggingface_hub.hf_api.file_exists = hf_api.file_exists
  665. if not hasattr(hf_api, '_whoami_origin'):
  666. # Patch whoami
  667. hf_api._whoami_origin = hf_api.whoami
  668. hf_api.whoami = MethodType(_whoami, api)
  669. huggingface_hub.whoami = hf_api.whoami
  670. huggingface_hub.hf_api.whoami = hf_api.whoami
  671. if not hasattr(hf_api, '_create_repo_origin'):
  672. # Patch create_repo
  673. from transformers.utils import hub
  674. hf_api._create_repo_origin = hf_api.create_repo
  675. hf_api.create_repo = MethodType(create_repo, api)
  676. huggingface_hub.create_repo = hf_api.create_repo
  677. huggingface_hub.hf_api.create_repo = hf_api.create_repo
  678. hub.create_repo = hf_api.create_repo
  679. if not hasattr(hf_api, '_upload_folder_origin'):
  680. # Patch upload_folder
  681. hf_api._upload_folder_origin = hf_api.upload_folder
  682. hf_api.upload_folder = MethodType(upload_folder, api)
  683. huggingface_hub.upload_folder = hf_api.upload_folder
  684. huggingface_hub.hf_api.upload_folder = hf_api.upload_folder
  685. if not hasattr(hf_api, '_upload_file_origin'):
  686. # Patch upload_file
  687. hf_api._upload_file_origin = hf_api.upload_file
  688. hf_api.upload_file = MethodType(upload_file, api)
  689. huggingface_hub.upload_file = hf_api.upload_file
  690. huggingface_hub.hf_api.upload_file = hf_api.upload_file
  691. repocard.upload_file = hf_api.upload_file
  692. if not hasattr(hf_api, '_create_commit_origin'):
  693. # Patch upload_file
  694. hf_api._create_commit_origin = hf_api.create_commit
  695. hf_api.create_commit = MethodType(create_commit, api)
  696. huggingface_hub.create_commit = hf_api.create_commit
  697. huggingface_hub.hf_api.create_commit = hf_api.create_commit
  698. from transformers.utils import hub
  699. hub.create_commit = hf_api.create_commit
  700. def _unpatch_hub():
  701. import huggingface_hub
  702. from huggingface_hub import hf_api
  703. from huggingface_hub import repocard
  704. if hasattr(repocard.RepoCard, '_validate_origin'):
  705. repocard.RepoCard.validate = repocard.RepoCard._validate_origin
  706. delattr(repocard.RepoCard, '_validate_origin')
  707. if hasattr(repocard.RepoCard, '_load_origin'):
  708. repocard.RepoCard.load = repocard.RepoCard._load_origin
  709. delattr(repocard.RepoCard, '_load_origin')
  710. if hasattr(hf_api, '_hf_hub_download_origin'):
  711. huggingface_hub.file_download.hf_hub_download = hf_api._hf_hub_download_origin
  712. huggingface_hub.hf_hub_download = hf_api._hf_hub_download_origin
  713. huggingface_hub.file_download.hf_hub_download = hf_api._hf_hub_download_origin
  714. delattr(hf_api, '_hf_hub_download_origin')
  715. if hasattr(hf_api, '_file_exists_origin'):
  716. hf_api.file_exists = hf_api._file_exists_origin
  717. huggingface_hub.file_exists = hf_api.file_exists
  718. huggingface_hub.hf_api.file_exists = hf_api.file_exists
  719. delattr(hf_api, '_file_exists_origin')
  720. if hasattr(hf_api, '_whoami_origin'):
  721. hf_api.whoami = hf_api._whoami_origin
  722. huggingface_hub.whoami = hf_api.whoami
  723. huggingface_hub.hf_api.whoami = hf_api.whoami
  724. delattr(hf_api, '_whoami_origin')
  725. if hasattr(hf_api, '_create_repo_origin'):
  726. from transformers.utils import hub
  727. hf_api.create_repo = hf_api._create_repo_origin
  728. huggingface_hub.create_repo = hf_api.create_repo
  729. huggingface_hub.hf_api.create_repo = hf_api.create_repo
  730. hub.create_repo = hf_api.create_repo
  731. delattr(hf_api, '_create_repo_origin')
  732. if hasattr(hf_api, '_upload_folder_origin'):
  733. hf_api.upload_folder = hf_api._upload_folder_origin
  734. huggingface_hub.upload_folder = hf_api.upload_folder
  735. huggingface_hub.hf_api.upload_folder = hf_api.upload_folder
  736. delattr(hf_api, '_upload_folder_origin')
  737. if hasattr(hf_api, '_upload_file_origin'):
  738. hf_api.upload_file = hf_api._upload_file_origin
  739. huggingface_hub.upload_file = hf_api.upload_file
  740. huggingface_hub.hf_api.upload_file = hf_api.upload_file
  741. repocard.upload_file = hf_api.upload_file
  742. delattr(hf_api, '_upload_file_origin')
  743. if hasattr(hf_api, '_create_commit_origin'):
  744. hf_api.create_commit = hf_api._create_commit_origin
  745. huggingface_hub.create_commit = hf_api.create_commit
  746. huggingface_hub.hf_api.create_commit = hf_api.create_commit
  747. from transformers.utils import hub
  748. hub.create_commit = hf_api.create_commit
  749. delattr(hf_api, '_create_commit_origin')
  750. def patch_hub():
  751. _patch_hub()
  752. _patch_pretrained_class(get_all_imported_modules())
  753. def unpatch_hub():
  754. _unpatch_pretrained_class(get_all_imported_modules())
  755. _unpatch_hub()
  756. @contextlib.contextmanager
  757. def patch_context():
  758. patch_hub()
  759. yield
  760. unpatch_hub()