hub_mixin.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853
  1. import inspect
  2. import json
  3. import os
  4. from dataclasses import Field, asdict, dataclass, is_dataclass
  5. from pathlib import Path
  6. from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union
  7. import packaging.version
  8. from . import constants
  9. from .errors import EntryNotFoundError, HfHubHTTPError
  10. from .file_download import hf_hub_download
  11. from .hf_api import HfApi
  12. from .repocard import ModelCard, ModelCardData
  13. from .utils import (
  14. SoftTemporaryDirectory,
  15. is_jsonable,
  16. is_safetensors_available,
  17. is_simple_optional_type,
  18. is_torch_available,
  19. logging,
  20. unwrap_simple_optional_type,
  21. validate_hf_hub_args,
  22. )
  23. if is_torch_available():
  24. import torch # type: ignore
  25. if is_safetensors_available():
  26. import safetensors
  27. from safetensors.torch import load_model as load_model_as_safetensor
  28. from safetensors.torch import save_model as save_model_as_safetensor
  29. logger = logging.get_logger(__name__)
# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349
class DataclassInstance(Protocol):
    """Structural type matching any object that exposes a `__dataclass_fields__` class attribute.

    Used in this module's annotations wherever a config may be supplied as a dataclass instance
    instead of a plain `dict`.
    """

    # Class-level mapping from field name to its `dataclasses.Field` descriptor.
    __dataclass_fields__: ClassVar[Dict[str, Field]]
# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="ModelHubMixin")
# Generic variable to represent an args type
ARGS_T = TypeVar("ARGS_T")
# Encoder: takes a value of the args type and returns an arbitrary (`Any`) encoded value.
ENCODER_T = Callable[[ARGS_T], Any]
# Decoder: the reverse direction, takes encoded data and returns a value of the args type.
DECODER_T = Callable[[Any], ARGS_T]
# An (encoder, decoder) pair, in that order.
CODER_T = Tuple[ENCODER_T, DECODER_T]
# Default model card template, used as the default value of `model_card_template` in
# `ModelHubMixin.__init_subclass__`. The `{{ ... }}` placeholders (`card_data`, `repo_url`,
# `paper_url`, `docs_url`) correspond to the keyword arguments that `generate_model_card`
# passes to `ModelCard.from_template`.
DEFAULT_MODEL_CARD = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---
This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
- Code: {{ repo_url | default("[More Information Needed]", true) }}
- Paper: {{ paper_url | default("[More Information Needed]", true) }}
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""
@dataclass
class MixinInfo:
    """Library-level information attached to a `ModelHubMixin` subclass and used to generate its model card."""

    # Template string rendered when generating the model card.
    model_card_template: str
    # Model card metadata (language, license, tags, ...).
    model_card_data: ModelCardData
    # Optional URLs forwarded to the template when the card is rendered.
    docs_url: Optional[str] = None
    paper_url: Optional[str] = None
    repo_url: Optional[str] = None
class ModelHubMixin:
    """
    A generic mixin to integrate ANY machine learning framework with the Hub.

    To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models
    has to be implemented by overriding [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a
    good example of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more
    instructions.

    When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to
    `__init__` but to the class definition itself. This is useful to define metadata about the library integrating
    [`ModelHubMixin`].

    For more details on how to integrate the mixin with your library, check out the
    [integration guide](../guides/integrations).

    Args:
        repo_url (`str`, *optional*):
            URL of the library repository. Used to generate model card.
        paper_url (`str`, *optional*):
            URL of the library paper. Used to generate model card.
        docs_url (`str`, *optional*):
            URL of the library documentation. Used to generate model card.
        model_card_template (`str`, *optional*):
            Template of the model card. Used to generate model card. Defaults to a generic template.
        language (`str` or `List[str]`, *optional*):
            Language supported by the library. Used to generate model card.
        library_name (`str`, *optional*):
            Name of the library integrating ModelHubMixin. Used to generate model card.
        license (`str`, *optional*):
            License of the library integrating ModelHubMixin. Used to generate model card.
            E.g: "apache-2.0"
        license_name (`str`, *optional*):
            Name of the license of the library integrating ModelHubMixin. Used to generate model card.
            Only used if `license` is set to `other`.
            E.g: "coqui-public-model-license".
        license_link (`str`, *optional*):
            URL to the license of the library integrating ModelHubMixin. Used to generate model card.
            Only used if `license` is set to `other` and `license_name` is set.
            E.g: "https://coqui.ai/cpml".
        pipeline_tag (`str`, *optional*):
            Tag of the pipeline. Used to generate model card. E.g. "text-classification".
        tags (`List[str]`, *optional*):
            Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"].
            The tag "model_hub_mixin" is always added, and the final list is de-duplicated and sorted.
        coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
            Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
            jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc.

    Example:

    ```python
    >>> from huggingface_hub import ModelHubMixin

    # Inherit from ModelHubMixin
    >>> class MyCustomModel(
    ...         ModelHubMixin,
    ...         library_name="my-library",
    ...         tags=["computer-vision"],
    ...         repo_url="https://github.com/huggingface/my-cool-library",
    ...         paper_url="https://arxiv.org/abs/2304.12244",
    ...         docs_url="https://huggingface.co/docs/my-cool-library",
    ...         # ^ optional metadata to generate model card
    ...     ):
    ...     def __init__(self, size: int = 512, device: str = "cpu"):
    ...         # define how to initialize your model
    ...         super().__init__()
    ...         ...
    ...
    ...     def _save_pretrained(self, save_directory: Path) -> None:
    ...         # define how to serialize your model
    ...         ...
    ...
    ...     @classmethod
    ...     def _from_pretrained(
    ...         cls: Type[T],
    ...         *,
    ...         model_id: str,
    ...         revision: Optional[str],
    ...         cache_dir: Optional[Union[str, Path]],
    ...         force_download: bool,
    ...         proxies: Optional[Dict],
    ...         resume_download: Optional[bool],
    ...         local_files_only: bool,
    ...         token: Optional[Union[str, bool]],
    ...         **model_kwargs,
    ...     ) -> T:
    ...         # define how to deserialize your model
    ...         ...

    >>> model = MyCustomModel(size=256, device="gpu")

    # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")

    # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")

    # Download and initialize weights from the Hub
    >>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
    >>> reloaded_model.size
    256

    # Model card has been correctly populated
    >>> from huggingface_hub import ModelCard
    >>> card = ModelCard.load("username/my-awesome-model")
    >>> card.data.tags
    ["computer-vision", "model_hub_mixin"]
    >>> card.data.library_name
    "my-library"
    ```
    """

    _hub_mixin_config: Optional[Union[dict, DataclassInstance]] = None
    # ^ optional config attribute automatically set in `from_pretrained`
    _hub_mixin_info: MixinInfo
    # ^ information about the library integrating ModelHubMixin (used to generate model card)
    _hub_mixin_inject_config: bool  # whether `_from_pretrained` expects `config` or not
    _hub_mixin_init_parameters: Dict[str, inspect.Parameter]  # __init__ parameters
    _hub_mixin_jsonable_default_values: Dict[str, Any]  # default values for __init__ parameters
    _hub_mixin_jsonable_custom_types: Tuple[Type, ...]  # custom types that can be encoded/decoded
    _hub_mixin_coders: Dict[Type, CODER_T]  # encoders/decoders for custom types
    # ^ internal values to handle config
  165. def __init_subclass__(
  166. cls,
  167. *,
  168. # Generic info for model card
  169. repo_url: Optional[str] = None,
  170. paper_url: Optional[str] = None,
  171. docs_url: Optional[str] = None,
  172. # Model card template
  173. model_card_template: str = DEFAULT_MODEL_CARD,
  174. # Model card metadata
  175. language: Optional[List[str]] = None,
  176. library_name: Optional[str] = None,
  177. license: Optional[str] = None,
  178. license_name: Optional[str] = None,
  179. license_link: Optional[str] = None,
  180. pipeline_tag: Optional[str] = None,
  181. tags: Optional[List[str]] = None,
  182. # How to encode/decode arguments with custom type into a JSON config?
  183. coders: Optional[
  184. Dict[Type, CODER_T]
  185. # Key is a type.
  186. # Value is a tuple (encoder, decoder).
  187. # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
  188. ] = None,
  189. ) -> None:
  190. """Inspect __init__ signature only once when subclassing + handle modelcard."""
  191. super().__init_subclass__()
  192. # Will be reused when creating modelcard
  193. tags = tags or []
  194. tags.append("model_hub_mixin")
  195. # Initialize MixinInfo if not existent
  196. info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData())
  197. # If parent class has a MixinInfo, inherit from it as a copy
  198. if hasattr(cls, "_hub_mixin_info"):
  199. # Inherit model card template from parent class if not explicitly set
  200. if model_card_template == DEFAULT_MODEL_CARD:
  201. info.model_card_template = cls._hub_mixin_info.model_card_template
  202. # Inherit from parent model card data
  203. info.model_card_data = ModelCardData(**cls._hub_mixin_info.model_card_data.to_dict())
  204. # Inherit other info
  205. info.docs_url = cls._hub_mixin_info.docs_url
  206. info.paper_url = cls._hub_mixin_info.paper_url
  207. info.repo_url = cls._hub_mixin_info.repo_url
  208. cls._hub_mixin_info = info
  209. # Update MixinInfo with metadata
  210. if model_card_template is not None and model_card_template != DEFAULT_MODEL_CARD:
  211. info.model_card_template = model_card_template
  212. if repo_url is not None:
  213. info.repo_url = repo_url
  214. if paper_url is not None:
  215. info.paper_url = paper_url
  216. if docs_url is not None:
  217. info.docs_url = docs_url
  218. if language is not None:
  219. info.model_card_data.language = language
  220. if library_name is not None:
  221. info.model_card_data.library_name = library_name
  222. if license is not None:
  223. info.model_card_data.license = license
  224. if license_name is not None:
  225. info.model_card_data.license_name = license_name
  226. if license_link is not None:
  227. info.model_card_data.license_link = license_link
  228. if pipeline_tag is not None:
  229. info.model_card_data.pipeline_tag = pipeline_tag
  230. if tags is not None:
  231. normalized_tags = list(tags)
  232. if info.model_card_data.tags is not None:
  233. info.model_card_data.tags.extend(normalized_tags)
  234. else:
  235. info.model_card_data.tags = normalized_tags
  236. if info.model_card_data.tags is not None:
  237. info.model_card_data.tags = sorted(set(info.model_card_data.tags))
  238. # Handle encoders/decoders for args
  239. cls._hub_mixin_coders = coders or {}
  240. cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys())
  241. # Inspect __init__ signature to handle config
  242. cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters)
  243. cls._hub_mixin_jsonable_default_values = {
  244. param.name: cls._encode_arg(param.default)
  245. for param in cls._hub_mixin_init_parameters.values()
  246. if param.default is not inspect.Parameter.empty and cls._is_jsonable(param.default)
  247. }
  248. cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters
  249. def __new__(cls: Type[T], *args, **kwargs) -> T:
  250. """Create a new instance of the class and handle config.
  251. 3 cases:
  252. - If `self._hub_mixin_config` is already set, do nothing.
  253. - If `config` is passed as a dataclass, set it as `self._hub_mixin_config`.
  254. - Otherwise, build `self._hub_mixin_config` from default values and passed values.
  255. """
  256. instance = super().__new__(cls)
  257. # If `config` is already set, return early
  258. if instance._hub_mixin_config is not None:
  259. return instance
  260. # Infer passed values
  261. passed_values = {
  262. **{
  263. key: value
  264. for key, value in zip(
  265. # [1:] to skip `self` parameter
  266. list(cls._hub_mixin_init_parameters)[1:],
  267. args,
  268. )
  269. },
  270. **kwargs,
  271. }
  272. # If config passed as dataclass => set it and return early
  273. if is_dataclass(passed_values.get("config")):
  274. instance._hub_mixin_config = passed_values["config"]
  275. return instance
  276. # Otherwise, build config from default + passed values
  277. init_config = {
  278. # default values
  279. **cls._hub_mixin_jsonable_default_values,
  280. # passed values
  281. **{
  282. key: cls._encode_arg(value) # Encode custom types as jsonable value
  283. for key, value in passed_values.items()
  284. if instance._is_jsonable(value) # Only if jsonable or we have a custom encoder
  285. },
  286. }
  287. passed_config = init_config.pop("config", {})
  288. # Populate `init_config` with provided config
  289. if isinstance(passed_config, dict):
  290. init_config.update(passed_config)
  291. # Set `config` attribute and return
  292. if init_config != {}:
  293. instance._hub_mixin_config = init_config
  294. return instance
  295. @classmethod
  296. def _is_jsonable(cls, value: Any) -> bool:
  297. """Check if a value is JSON serializable."""
  298. if is_dataclass(value):
  299. return True
  300. if isinstance(value, cls._hub_mixin_jsonable_custom_types):
  301. return True
  302. return is_jsonable(value)
  303. @classmethod
  304. def _encode_arg(cls, arg: Any) -> Any:
  305. """Encode an argument into a JSON serializable format."""
  306. if is_dataclass(arg):
  307. return asdict(arg) # type: ignore[arg-type]
  308. for type_, (encoder, _) in cls._hub_mixin_coders.items():
  309. if isinstance(arg, type_):
  310. if arg is None:
  311. return None
  312. return encoder(arg)
  313. return arg
  314. @classmethod
  315. def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]:
  316. """Decode a JSON serializable value into an argument."""
  317. if is_simple_optional_type(expected_type):
  318. if value is None:
  319. return None
  320. expected_type = unwrap_simple_optional_type(expected_type)
  321. # Dataclass => handle it
  322. if is_dataclass(expected_type):
  323. return _load_dataclass(expected_type, value) # type: ignore[return-value]
  324. # Otherwise => check custom decoders
  325. for type_, (_, decoder) in cls._hub_mixin_coders.items():
  326. if inspect.isclass(expected_type) and issubclass(expected_type, type_):
  327. return decoder(value)
  328. # Otherwise => don't decode
  329. return value
  330. def save_pretrained(
  331. self,
  332. save_directory: Union[str, Path],
  333. *,
  334. config: Optional[Union[dict, DataclassInstance]] = None,
  335. repo_id: Optional[str] = None,
  336. push_to_hub: bool = False,
  337. model_card_kwargs: Optional[Dict[str, Any]] = None,
  338. **push_to_hub_kwargs,
  339. ) -> Optional[str]:
  340. """
  341. Save weights in local directory.
  342. Args:
  343. save_directory (`str` or `Path`):
  344. Path to directory in which the model weights and configuration will be saved.
  345. config (`dict` or `DataclassInstance`, *optional*):
  346. Model configuration specified as a key/value dictionary or a dataclass instance.
  347. push_to_hub (`bool`, *optional*, defaults to `False`):
  348. Whether or not to push your model to the Huggingface Hub after saving it.
  349. repo_id (`str`, *optional*):
  350. ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
  351. not provided.
  352. model_card_kwargs (`Dict[str, Any]`, *optional*):
  353. Additional arguments passed to the model card template to customize the model card.
  354. push_to_hub_kwargs:
  355. Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
  356. Returns:
  357. `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
  358. """
  359. save_directory = Path(save_directory)
  360. save_directory.mkdir(parents=True, exist_ok=True)
  361. # Remove config.json if already exists. After `_save_pretrained` we don't want to overwrite config.json
  362. # as it might have been saved by the custom `_save_pretrained` already. However we do want to overwrite
  363. # an existing config.json if it was not saved by `_save_pretrained`.
  364. config_path = save_directory / constants.CONFIG_NAME
  365. config_path.unlink(missing_ok=True)
  366. # save model weights/files (framework-specific)
  367. self._save_pretrained(save_directory)
  368. # save config (if provided and if not serialized yet in `_save_pretrained`)
  369. if config is None:
  370. config = self._hub_mixin_config
  371. if config is not None:
  372. if is_dataclass(config):
  373. config = asdict(config) # type: ignore[arg-type]
  374. if not config_path.exists():
  375. config_str = json.dumps(config, sort_keys=True, indent=2)
  376. config_path.write_text(config_str)
  377. # save model card
  378. model_card_path = save_directory / "README.md"
  379. model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {}
  380. if not model_card_path.exists(): # do not overwrite if already exists
  381. self.generate_model_card(**model_card_kwargs).save(save_directory / "README.md")
  382. # push to the Hub if required
  383. if push_to_hub:
  384. kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input
  385. if config is not None: # kwarg for `push_to_hub`
  386. kwargs["config"] = config
  387. if repo_id is None:
  388. repo_id = save_directory.name # Defaults to `save_directory` name
  389. return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs)
  390. return None
  391. def _save_pretrained(self, save_directory: Path) -> None:
  392. """
  393. Overwrite this method in subclass to define how to save your model.
  394. Check out our [integration guide](../guides/integrations) for instructions.
  395. Args:
  396. save_directory (`str` or `Path`):
  397. Path to directory in which the model weights and configuration will be saved.
  398. """
  399. raise NotImplementedError
    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls: Type[T],
        pretrained_model_name_or_path: Union[str, Path],
        *,
        force_download: bool = False,
        resume_download: Optional[bool] = None,
        proxies: Optional[Dict] = None,
        token: Optional[Union[str, bool]] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        local_files_only: bool = False,
        revision: Optional[str] = None,
        **model_kwargs,
    ) -> T:
        """
        Download a model from the Huggingface Hub and instantiate it.

        The config file (if found) is read, its values are decoded according to the `__init__` annotations
        and injected into `model_kwargs` without overriding values passed explicitly. Actual loading is
        delegated to `_from_pretrained`.

        Args:
            pretrained_model_name_or_path (`str`, `Path`):
                - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`.
                - Or a path to a `directory` containing model weights saved using
                    [`~ModelHubMixin.save_pretrained`], e.g., `../path/to/my_model_directory/`.
            revision (`str`, *optional*):
                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
                Defaults to the latest commit on `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
            resume_download (`bool`, *optional*):
                Forwarded unchanged to `hf_hub_download` and `_from_pretrained`.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `hf auth login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            model_kwargs (`Dict`, *optional*):
                Additional kwargs to pass to the model during initialization.

        Returns:
            The instance returned by `_from_pretrained`. If a config was loaded and the instance has no
            config yet (`None` or `{}`), the loaded config is stored as `_hub_mixin_config`.
        """
        model_id = str(pretrained_model_name_or_path)
        config_file: Optional[str] = None
        if os.path.isdir(model_id):
            # Local directory: look for the config file next to the weights
            if constants.CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, constants.CONFIG_NAME)
            else:
                logger.warning(f"{constants.CONFIG_NAME} not found in {Path(model_id).resolve()}")
        else:
            # Otherwise treat `model_id` as a Hub repo id. A missing config is not fatal: the error is
            # logged and loading continues without config.
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=constants.CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                logger.info(f"{constants.CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")

        # Read config
        config = None
        if config_file is not None:
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)

            # Decode custom types in config
            # NOTE(review): assumes the JSON document is an object (dict) — confirm for third-party repos
            for key, value in config.items():
                if key in cls._hub_mixin_init_parameters:
                    expected_type = cls._hub_mixin_init_parameters[key].annotation
                    if expected_type is not inspect.Parameter.empty:
                        config[key] = cls._decode_arg(expected_type, value)

            # Populate model_kwargs from config (explicitly passed kwargs take precedence)
            for param in cls._hub_mixin_init_parameters.values():
                if param.name not in model_kwargs and param.name in config:
                    model_kwargs[param.name] = config[param.name]

            # Check if `config` argument was passed at init
            if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs:
                # Decode `config` argument if it was passed
                config_annotation = cls._hub_mixin_init_parameters["config"].annotation
                config = cls._decode_arg(config_annotation, config)

                # Forward config to model initialization
                model_kwargs["config"] = config

            # Inject config if `**kwargs` are expected
            if is_dataclass(cls):
                for key in cls.__dataclass_fields__:
                    if key not in model_kwargs and key in config:
                        model_kwargs[key] = config[key]
            elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
                for key, value in config.items():
                    if key not in model_kwargs:
                        model_kwargs[key] = value

            # Finally, also inject if `_from_pretrained` expects it
            if cls._hub_mixin_inject_config and "config" not in model_kwargs:
                model_kwargs["config"] = config

        instance = cls._from_pretrained(
            model_id=str(model_id),
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            **model_kwargs,
        )

        # Implicitly set the config as instance attribute if not already set by the class
        # This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
        if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})):
            instance._hub_mixin_config = config

        return instance
  513. @classmethod
  514. def _from_pretrained(
  515. cls: Type[T],
  516. *,
  517. model_id: str,
  518. revision: Optional[str],
  519. cache_dir: Optional[Union[str, Path]],
  520. force_download: bool,
  521. proxies: Optional[Dict],
  522. resume_download: Optional[bool],
  523. local_files_only: bool,
  524. token: Optional[Union[str, bool]],
  525. **model_kwargs,
  526. ) -> T:
  527. """Overwrite this method in subclass to define how to load your model from pretrained.
  528. Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most
  529. args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this
  530. method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location`
  531. parameter to set on which device the model should be loaded.
  532. Check out our [integration guide](../guides/integrations) for more instructions.
  533. Args:
  534. model_id (`str`):
  535. ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`).
  536. revision (`str`, *optional*):
  537. Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
  538. latest commit on `main` branch.
  539. force_download (`bool`, *optional*, defaults to `False`):
  540. Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
  541. the existing cache.
  542. proxies (`Dict[str, str]`, *optional*):
  543. A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
  544. 'http://hostname': 'foo.bar:4012'}`).
  545. token (`str` or `bool`, *optional*):
  546. The token to use as HTTP bearer authorization for remote files. By default, it will use the token
  547. cached when running `hf auth login`.
  548. cache_dir (`str`, `Path`, *optional*):
  549. Path to the folder where cached files are stored.
  550. local_files_only (`bool`, *optional*, defaults to `False`):
  551. If `True`, avoid downloading the file and return the path to the local cached file if it exists.
  552. model_kwargs:
  553. Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method.
  554. """
  555. raise NotImplementedError
  556. @validate_hf_hub_args
  557. def push_to_hub(
  558. self,
  559. repo_id: str,
  560. *,
  561. config: Optional[Union[dict, DataclassInstance]] = None,
  562. commit_message: str = "Push model using huggingface_hub.",
  563. private: Optional[bool] = None,
  564. token: Optional[str] = None,
  565. branch: Optional[str] = None,
  566. create_pr: Optional[bool] = None,
  567. allow_patterns: Optional[Union[List[str], str]] = None,
  568. ignore_patterns: Optional[Union[List[str], str]] = None,
  569. delete_patterns: Optional[Union[List[str], str]] = None,
  570. model_card_kwargs: Optional[Dict[str, Any]] = None,
  571. ) -> str:
  572. """
  573. Upload model checkpoint to the Hub.
  574. Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
  575. `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
  576. details.
  577. Args:
  578. repo_id (`str`):
  579. ID of the repository to push to (example: `"username/my-model"`).
  580. config (`dict` or `DataclassInstance`, *optional*):
  581. Model configuration specified as a key/value dictionary or a dataclass instance.
  582. commit_message (`str`, *optional*):
  583. Message to commit while pushing.
  584. private (`bool`, *optional*):
  585. Whether the repository created should be private.
  586. If `None` (default), the repo will be public unless the organization's default is private.
  587. token (`str`, *optional*):
  588. The token to use as HTTP bearer authorization for remote files. By default, it will use the token
  589. cached when running `hf auth login`.
  590. branch (`str`, *optional*):
  591. The git branch on which to push the model. This defaults to `"main"`.
  592. create_pr (`boolean`, *optional*):
  593. Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
  594. allow_patterns (`List[str]` or `str`, *optional*):
  595. If provided, only files matching at least one pattern are pushed.
  596. ignore_patterns (`List[str]` or `str`, *optional*):
  597. If provided, files matching any of the patterns are not pushed.
  598. delete_patterns (`List[str]` or `str`, *optional*):
  599. If provided, remote files matching any of the patterns will be deleted from the repo.
  600. model_card_kwargs (`Dict[str, Any]`, *optional*):
  601. Additional arguments passed to the model card template to customize the model card.
  602. Returns:
  603. The url of the commit of your model in the given repository.
  604. """
  605. api = HfApi(token=token)
  606. repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
  607. # Push the files to the repo in a single commit
  608. with SoftTemporaryDirectory() as tmp:
  609. saved_path = Path(tmp) / repo_id
  610. self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs)
  611. return api.upload_folder(
  612. repo_id=repo_id,
  613. repo_type="model",
  614. folder_path=saved_path,
  615. commit_message=commit_message,
  616. revision=branch,
  617. create_pr=create_pr,
  618. allow_patterns=allow_patterns,
  619. ignore_patterns=ignore_patterns,
  620. delete_patterns=delete_patterns,
  621. )
  622. def generate_model_card(self, *args, **kwargs) -> ModelCard:
  623. card = ModelCard.from_template(
  624. card_data=self._hub_mixin_info.model_card_data,
  625. template_str=self._hub_mixin_info.model_card_template,
  626. repo_url=self._hub_mixin_info.repo_url,
  627. paper_url=self._hub_mixin_info.paper_url,
  628. docs_url=self._hub_mixin_info.docs_url,
  629. **kwargs,
  630. )
  631. return card
  632. class PyTorchModelHubMixin(ModelHubMixin):
  633. """
  634. Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
  635. is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
  636. you should first set it back in training mode with `model.train()`.
  637. See [`ModelHubMixin`] for more details on how to use the mixin.
  638. Example:
  639. ```python
  640. >>> import torch
  641. >>> import torch.nn as nn
  642. >>> from huggingface_hub import PyTorchModelHubMixin
  643. >>> class MyModel(
  644. ... nn.Module,
  645. ... PyTorchModelHubMixin,
  646. ... library_name="keras-nlp",
  647. ... repo_url="https://github.com/keras-team/keras-nlp",
  648. ... paper_url="https://arxiv.org/abs/2304.12244",
  649. ... docs_url="https://keras.io/keras_nlp/",
  650. ... # ^ optional metadata to generate model card
  651. ... ):
  652. ... def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4):
  653. ... super().__init__()
  654. ... self.param = nn.Parameter(torch.rand(hidden_size, vocab_size))
  655. ... self.linear = nn.Linear(output_size, vocab_size)
  656. ... def forward(self, x):
  657. ... return self.linear(x + self.param)
  658. >>> model = MyModel(hidden_size=256)
  659. # Save model weights to local directory
  660. >>> model.save_pretrained("my-awesome-model")
  661. # Push model weights to the Hub
  662. >>> model.push_to_hub("my-awesome-model")
  663. # Download and initialize weights from the Hub
  664. >>> model = MyModel.from_pretrained("username/my-awesome-model")
  665. >>> model.hidden_size
  666. 256
  667. ```
  668. """
  669. def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None:
  670. tags = tags or []
  671. tags.append("pytorch_model_hub_mixin")
  672. kwargs["tags"] = tags
  673. return super().__init_subclass__(*args, **kwargs)
  674. def _save_pretrained(self, save_directory: Path) -> None:
  675. """Save weights from a Pytorch model to a local directory."""
  676. model_to_save = self.module if hasattr(self, "module") else self # type: ignore
  677. save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE)) # type: ignore [arg-type]
  678. @classmethod
  679. def _from_pretrained(
  680. cls,
  681. *,
  682. model_id: str,
  683. revision: Optional[str],
  684. cache_dir: Optional[Union[str, Path]],
  685. force_download: bool,
  686. proxies: Optional[Dict],
  687. resume_download: Optional[bool],
  688. local_files_only: bool,
  689. token: Union[str, bool, None],
  690. map_location: str = "cpu",
  691. strict: bool = False,
  692. **model_kwargs,
  693. ):
  694. """Load Pytorch pretrained weights and return the loaded model."""
  695. model = cls(**model_kwargs)
  696. if os.path.isdir(model_id):
  697. print("Loading weights from local directory")
  698. model_file = os.path.join(model_id, constants.SAFETENSORS_SINGLE_FILE)
  699. return cls._load_as_safetensor(model, model_file, map_location, strict)
  700. else:
  701. try:
  702. model_file = hf_hub_download(
  703. repo_id=model_id,
  704. filename=constants.SAFETENSORS_SINGLE_FILE,
  705. revision=revision,
  706. cache_dir=cache_dir,
  707. force_download=force_download,
  708. proxies=proxies,
  709. resume_download=resume_download,
  710. token=token,
  711. local_files_only=local_files_only,
  712. )
  713. return cls._load_as_safetensor(model, model_file, map_location, strict)
  714. except EntryNotFoundError:
  715. model_file = hf_hub_download(
  716. repo_id=model_id,
  717. filename=constants.PYTORCH_WEIGHTS_NAME,
  718. revision=revision,
  719. cache_dir=cache_dir,
  720. force_download=force_download,
  721. proxies=proxies,
  722. resume_download=resume_download,
  723. token=token,
  724. local_files_only=local_files_only,
  725. )
  726. return cls._load_as_pickle(model, model_file, map_location, strict)
  727. @classmethod
  728. def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
  729. state_dict = torch.load(model_file, map_location=torch.device(map_location), weights_only=True)
  730. model.load_state_dict(state_dict, strict=strict) # type: ignore
  731. model.eval() # type: ignore
  732. return model
  733. @classmethod
  734. def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
  735. if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"): # type: ignore [attr-defined]
  736. load_model_as_safetensor(model, model_file, strict=strict) # type: ignore [arg-type]
  737. if map_location != "cpu":
  738. logger.warning(
  739. "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
  740. " This means that the model is loaded on 'cpu' first and then copied to the device."
  741. " This leads to a slower loading time."
  742. " Please update safetensors to version 0.4.3 or above for improved performance."
  743. )
  744. model.to(map_location) # type: ignore [attr-defined]
  745. else:
  746. safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) # type: ignore [arg-type]
  747. return model
  748. def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance:
  749. """Load a dataclass instance from a dictionary.
  750. Fields not expected by the dataclass are ignored.
  751. """
  752. return datacls(**{k: v for k, v in data.items() if k in datacls.__dataclass_fields__})