dataclasses.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615
  1. import inspect
  2. import sys
  3. import types
  4. from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields, make_dataclass
  5. from functools import lru_cache, wraps
  6. from typing import (
  7. Annotated,
  8. Any,
  9. Callable,
  10. ForwardRef,
  11. Literal,
  12. Optional,
  13. Type,
  14. TypeVar,
  15. Union,
  16. get_args,
  17. get_origin,
  18. overload,
  19. )
  20. try:
  21. # Python 3.11+
  22. from typing import NotRequired, Required # type: ignore
  23. except ImportError:
  24. try:
  25. # In case typing_extensions is installed
  26. from typing_extensions import NotRequired, Required # type: ignore
  27. except ImportError:
  28. # Fallback: create dummy types that will never match
  29. Required = type("Required", (), {}) # type: ignore
  30. NotRequired = type("NotRequired", (), {}) # type: ignore
  31. from .errors import (
  32. StrictDataclassClassValidationError,
  33. StrictDataclassDefinitionError,
  34. StrictDataclassFieldValidationError,
  35. )
  36. Validator_T = Callable[[Any], None]
  37. T = TypeVar("T")
  38. TypedDictType = TypeVar("TypedDictType", bound=dict[str, Any])
  39. _TYPED_DICT_DEFAULT_VALUE = object() # used as default value in TypedDict fields (to distinguish from None)
  40. # The overload decorator helps type checkers understand the different return types
  41. @overload
  42. def strict(cls: Type[T]) -> Type[T]: ...
  43. @overload
  44. def strict(*, accept_kwargs: bool = False) -> Callable[[Type[T]], Type[T]]: ...
  45. def strict(
  46. cls: Optional[Type[T]] = None, *, accept_kwargs: bool = False
  47. ) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
  48. """
  49. Decorator to add strict validation to a dataclass.
  50. This decorator must be used on top of `@dataclass` to ensure IDEs and static typing tools
  51. recognize the class as a dataclass.
  52. Can be used with or without arguments:
  53. - `@strict`
  54. - `@strict(accept_kwargs=True)`
  55. Args:
  56. cls:
  57. The class to convert to a strict dataclass.
  58. accept_kwargs (`bool`, *optional*):
  59. If True, allows arbitrary keyword arguments in `__init__`. Defaults to False.
  60. Returns:
  61. The enhanced dataclass with strict validation on field assignment.
  62. Example:
  63. ```py
  64. >>> from dataclasses import dataclass
  65. >>> from huggingface_hub.dataclasses import as_validated_field, strict, validated_field
  66. >>> @as_validated_field
  67. >>> def positive_int(value: int):
  68. ... if not value >= 0:
  69. ... raise ValueError(f"Value must be positive, got {value}")
  70. >>> @strict(accept_kwargs=True)
  71. ... @dataclass
  72. ... class User:
  73. ... name: str
  74. ... age: int = positive_int(default=10)
  75. # Initialize
  76. >>> User(name="John")
  77. User(name='John', age=10)
  78. # Extra kwargs are accepted
  79. >>> User(name="John", age=30, lastname="Doe")
  80. User(name='John', age=30, *lastname='Doe')
  81. # Invalid type => raises
  82. >>> User(name="John", age="30")
  83. huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
  84. TypeError: Field 'age' expected int, got str (value: '30')
  85. # Invalid value => raises
  86. >>> User(name="John", age=-1)
  87. huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
  88. ValueError: Value must be positive, got -1
  89. ```
  90. """
  91. def wrap(cls: Type[T]) -> Type[T]:
  92. if not hasattr(cls, "__dataclass_fields__"):
  93. raise StrictDataclassDefinitionError(
  94. f"Class '{cls.__name__}' must be a dataclass before applying @strict."
  95. )
  96. # List and store validators
  97. field_validators: dict[str, list[Validator_T]] = {}
  98. for f in fields(cls): # type: ignore [arg-type]
  99. validators = []
  100. validators.append(_create_type_validator(f))
  101. custom_validator = f.metadata.get("validator")
  102. if custom_validator is not None:
  103. if not isinstance(custom_validator, list):
  104. custom_validator = [custom_validator]
  105. for validator in custom_validator:
  106. if not _is_validator(validator):
  107. raise StrictDataclassDefinitionError(
  108. f"Invalid validator for field '{f.name}': {validator}. Must be a callable taking a single argument."
  109. )
  110. validators.extend(custom_validator)
  111. field_validators[f.name] = validators
  112. cls.__validators__ = field_validators # type: ignore
  113. # Override __setattr__ to validate fields on assignment
  114. original_setattr = cls.__setattr__
  115. def __strict_setattr__(self: Any, name: str, value: Any) -> None:
  116. """Custom __setattr__ method for strict dataclasses."""
  117. # Run all validators
  118. for validator in self.__validators__.get(name, []):
  119. try:
  120. validator(value)
  121. except (ValueError, TypeError) as e:
  122. raise StrictDataclassFieldValidationError(field=name, cause=e) from e
  123. # If validation passed, set the attribute
  124. original_setattr(self, name, value)
  125. cls.__setattr__ = __strict_setattr__ # type: ignore[method-assign]
  126. if accept_kwargs:
  127. # (optional) Override __init__ to accept arbitrary keyword arguments
  128. original_init = cls.__init__
  129. @wraps(original_init)
  130. def __init__(self, **kwargs: Any) -> None:
  131. # Extract only the fields that are part of the dataclass
  132. dataclass_fields = {f.name for f in fields(cls)} # type: ignore [arg-type]
  133. standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields}
  134. # Call the original __init__ with standard fields
  135. original_init(self, **standard_kwargs)
  136. # Add any additional kwargs as attributes
  137. for name, value in kwargs.items():
  138. if name not in dataclass_fields:
  139. self.__setattr__(name, value)
  140. cls.__init__ = __init__ # type: ignore[method-assign]
  141. # (optional) Override __repr__ to include additional kwargs
  142. original_repr = cls.__repr__
  143. @wraps(original_repr)
  144. def __repr__(self) -> str:
  145. # Call the original __repr__ to get the standard fields
  146. standard_repr = original_repr(self)
  147. # Get additional kwargs
  148. additional_kwargs = [
  149. # add a '*' in front of additional kwargs to let the user know they are not part of the dataclass
  150. f"*{k}={v!r}"
  151. for k, v in self.__dict__.items()
  152. if k not in cls.__dataclass_fields__ # type: ignore [attr-defined]
  153. ]
  154. additional_repr = ", ".join(additional_kwargs)
  155. # Combine both representations
  156. return f"{standard_repr[:-1]}, {additional_repr})" if additional_kwargs else standard_repr
  157. cls.__repr__ = __repr__ # type: ignore [method-assign]
  158. # List all public methods starting with `validate_` => class validators.
  159. class_validators = []
  160. for name in dir(cls):
  161. if not name.startswith("validate_"):
  162. continue
  163. method = getattr(cls, name)
  164. if not callable(method):
  165. continue
  166. if len(inspect.signature(method).parameters) != 1:
  167. raise StrictDataclassDefinitionError(
  168. f"Class '{cls.__name__}' has a class validator '{name}' that takes more than one argument."
  169. " Class validators must take only 'self' as an argument. Methods starting with 'validate_'"
  170. " are considered to be class validators."
  171. )
  172. class_validators.append(method)
  173. cls.__class_validators__ = class_validators # type: ignore [attr-defined]
  174. # Add `validate` method to the class, but first check if it already exists
  175. def validate(self: T) -> None:
  176. """Run class validators on the instance."""
  177. for validator in cls.__class_validators__: # type: ignore [attr-defined]
  178. try:
  179. validator(self)
  180. except (ValueError, TypeError) as e:
  181. raise StrictDataclassClassValidationError(validator=validator.__name__, cause=e) from e
  182. # Hack to be able to raise if `.validate()` already exists except if it was created by this decorator on a parent class
  183. # (in which case we just override it)
  184. validate.__is_defined_by_strict_decorator__ = True # type: ignore [attr-defined]
  185. if hasattr(cls, "validate"):
  186. if not getattr(cls.validate, "__is_defined_by_strict_decorator__", False): # type: ignore [attr-defined]
  187. raise StrictDataclassDefinitionError(
  188. f"Class '{cls.__name__}' already implements a method called 'validate'."
  189. " This method name is reserved when using the @strict decorator on a dataclass."
  190. " If you want to keep your own method, please rename it."
  191. )
  192. cls.validate = validate # type: ignore
  193. # Run class validators after initialization
  194. initial_init = cls.__init__
  195. @wraps(initial_init)
  196. def init_with_validate(self, *args, **kwargs) -> None:
  197. """Run class validators after initialization."""
  198. initial_init(self, *args, **kwargs) # type: ignore [call-arg]
  199. cls.validate(self) # type: ignore [attr-defined]
  200. setattr(cls, "__init__", init_with_validate)
  201. return cls
  202. # Return wrapped class or the decorator itself
  203. return wrap(cls) if cls is not None else wrap
  204. def validate_typed_dict(schema: type[TypedDictType], data: dict) -> None:
  205. """
  206. Validate that a dictionary conforms to the types defined in a TypedDict class.
  207. Under the hood, the typed dict is converted to a strict dataclass and validated using the `@strict` decorator.
  208. Args:
  209. schema (`type[TypedDictType]`):
  210. The TypedDict class defining the expected structure and types.
  211. data (`dict`):
  212. The dictionary to validate.
  213. Raises:
  214. `StrictDataclassFieldValidationError`:
  215. If any field in the dictionary does not conform to the expected type.
  216. Example:
  217. ```py
  218. >>> from typing import Annotated, TypedDict
  219. >>> from huggingface_hub.dataclasses import validate_typed_dict
  220. >>> def positive_int(value: int):
  221. ... if not value >= 0:
  222. ... raise ValueError(f"Value must be positive, got {value}")
  223. >>> class User(TypedDict):
  224. ... name: str
  225. ... age: Annotated[int, positive_int]
  226. >>> # Valid data
  227. >>> validate_typed_dict(User, {"name": "John", "age": 30})
  228. >>> # Invalid type for age
  229. >>> validate_typed_dict(User, {"name": "John", "age": "30"})
  230. huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
  231. TypeError: Field 'age' expected int, got str (value: '30')
  232. >>> # Invalid value for age
  233. >>> validate_typed_dict(User, {"name": "John", "age": -1})
  234. huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
  235. ValueError: Value must be positive, got -1
  236. ```
  237. """
  238. # Convert typed dict to dataclass
  239. strict_cls = _build_strict_cls_from_typed_dict(schema)
  240. # Validate the data by instantiating the strict dataclass
  241. strict_cls(**data) # will raise if validation fails
  242. @lru_cache
  243. def _build_strict_cls_from_typed_dict(schema: type[TypedDictType]) -> Type:
  244. # Extract type hints from the TypedDict class
  245. type_hints = _get_typed_dict_annotations(schema)
  246. # If the TypedDict is not total, wrap fields as NotRequired (unless explicitly Required or NotRequired)
  247. if not getattr(schema, "__total__", True):
  248. for key, value in type_hints.items():
  249. origin = get_origin(value)
  250. if origin is Annotated:
  251. base, *meta = get_args(value)
  252. if not _is_required_or_notrequired(base):
  253. base = NotRequired[base]
  254. type_hints[key] = Annotated[tuple([base] + list(meta))] # type: ignore
  255. elif not _is_required_or_notrequired(value):
  256. type_hints[key] = NotRequired[value]
  257. # Convert type hints to dataclass fields
  258. fields = []
  259. for key, value in type_hints.items():
  260. if get_origin(value) is Annotated:
  261. base, *meta = get_args(value)
  262. fields.append((key, base, field(default=_TYPED_DICT_DEFAULT_VALUE, metadata={"validator": meta[0]})))
  263. else:
  264. fields.append((key, value, field(default=_TYPED_DICT_DEFAULT_VALUE)))
  265. # Create a strict dataclass from the TypedDict fields
  266. return strict(make_dataclass(schema.__name__, fields))
  267. def _get_typed_dict_annotations(schema: type[TypedDictType]) -> dict[str, Any]:
  268. """Extract type annotations from a TypedDict class."""
  269. try:
  270. # Available in Python 3.14+
  271. import annotationlib
  272. return annotationlib.get_annotations(schema)
  273. except ImportError:
  274. return {
  275. # We do not use `get_type_hints` here to avoid evaluating ForwardRefs (which might fail).
  276. # ForwardRefs are not validated by @strict anyway.
  277. name: value if value is not None else type(None)
  278. for name, value in schema.__dict__.get("__annotations__", {}).items()
  279. }
  280. def validated_field(
  281. validator: Union[list[Validator_T], Validator_T],
  282. default: Union[Any, _MISSING_TYPE] = MISSING,
  283. default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING,
  284. init: bool = True,
  285. repr: bool = True,
  286. hash: Optional[bool] = None,
  287. compare: bool = True,
  288. metadata: Optional[dict] = None,
  289. **kwargs: Any,
  290. ) -> Any:
  291. """
  292. Create a dataclass field with a custom validator.
  293. Useful to apply several checks to a field. If only applying one rule, check out the [`as_validated_field`] decorator.
  294. Args:
  295. validator (`Callable` or `list[Callable]`):
  296. A method that takes a value as input and raises ValueError/TypeError if the value is invalid.
  297. Can be a list of validators to apply multiple checks.
  298. **kwargs:
  299. Additional arguments to pass to `dataclasses.field()`.
  300. Returns:
  301. A field with the validator attached in metadata
  302. """
  303. if not isinstance(validator, list):
  304. validator = [validator]
  305. if metadata is None:
  306. metadata = {}
  307. metadata["validator"] = validator
  308. return field( # type: ignore
  309. default=default, # type: ignore [arg-type]
  310. default_factory=default_factory, # type: ignore [arg-type]
  311. init=init,
  312. repr=repr,
  313. hash=hash,
  314. compare=compare,
  315. metadata=metadata,
  316. **kwargs,
  317. )
  318. def as_validated_field(validator: Validator_T):
  319. """
  320. Decorates a validator function as a [`validated_field`] (i.e. a dataclass field with a custom validator).
  321. Args:
  322. validator (`Callable`):
  323. A method that takes a value as input and raises ValueError/TypeError if the value is invalid.
  324. """
  325. def _inner(
  326. default: Union[Any, _MISSING_TYPE] = MISSING,
  327. default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING,
  328. init: bool = True,
  329. repr: bool = True,
  330. hash: Optional[bool] = None,
  331. compare: bool = True,
  332. metadata: Optional[dict] = None,
  333. **kwargs: Any,
  334. ):
  335. return validated_field(
  336. validator,
  337. default=default,
  338. default_factory=default_factory,
  339. init=init,
  340. repr=repr,
  341. hash=hash,
  342. compare=compare,
  343. metadata=metadata,
  344. **kwargs,
  345. )
  346. return _inner
  347. def type_validator(name: str, value: Any, expected_type: Any) -> None:
  348. """Validate that 'value' matches 'expected_type'."""
  349. origin = get_origin(expected_type)
  350. args = get_args(expected_type)
  351. if expected_type is Any:
  352. return
  353. elif validator := _BASIC_TYPE_VALIDATORS.get(origin):
  354. validator(name, value, args)
  355. elif isinstance(expected_type, type): # simple types
  356. _validate_simple_type(name, value, expected_type)
  357. elif isinstance(expected_type, ForwardRef) or isinstance(expected_type, str):
  358. return
  359. elif origin is Required:
  360. if value is _TYPED_DICT_DEFAULT_VALUE:
  361. raise TypeError(f"Field '{name}' is required but missing.")
  362. type_validator(name, value, args[0])
  363. elif origin is NotRequired:
  364. if value is _TYPED_DICT_DEFAULT_VALUE:
  365. return
  366. type_validator(name, value, args[0])
  367. else:
  368. raise TypeError(f"Unsupported type for field '{name}': {expected_type}")
  369. def _validate_union(name: str, value: Any, args: tuple[Any, ...]) -> None:
  370. """Validate that value matches one of the types in a Union."""
  371. errors = []
  372. for t in args:
  373. try:
  374. type_validator(name, value, t)
  375. return # Valid if any type matches
  376. except TypeError as e:
  377. errors.append(str(e))
  378. raise TypeError(
  379. f"Field '{name}' with value {repr(value)} doesn't match any type in {args}. Errors: {'; '.join(errors)}"
  380. )
  381. def _validate_literal(name: str, value: Any, args: tuple[Any, ...]) -> None:
  382. """Validate Literal type."""
  383. if value not in args:
  384. raise TypeError(f"Field '{name}' expected one of {args}, got {value}")
  385. def _validate_list(name: str, value: Any, args: tuple[Any, ...]) -> None:
  386. """Validate list[T] type."""
  387. if not isinstance(value, list):
  388. raise TypeError(f"Field '{name}' expected a list, got {type(value).__name__}")
  389. # Validate each item in the list
  390. item_type = args[0]
  391. for i, item in enumerate(value):
  392. try:
  393. type_validator(f"{name}[{i}]", item, item_type)
  394. except TypeError as e:
  395. raise TypeError(f"Invalid item at index {i} in list '{name}'") from e
  396. def _validate_dict(name: str, value: Any, args: tuple[Any, ...]) -> None:
  397. """Validate dict[K, V] type."""
  398. if not isinstance(value, dict):
  399. raise TypeError(f"Field '{name}' expected a dict, got {type(value).__name__}")
  400. # Validate keys and values
  401. key_type, value_type = args
  402. for k, v in value.items():
  403. try:
  404. type_validator(f"{name}.key", k, key_type)
  405. type_validator(f"{name}[{k!r}]", v, value_type)
  406. except TypeError as e:
  407. raise TypeError(f"Invalid key or value in dict '{name}'") from e
  408. def _validate_tuple(name: str, value: Any, args: tuple[Any, ...]) -> None:
  409. """Validate Tuple type."""
  410. if not isinstance(value, tuple):
  411. raise TypeError(f"Field '{name}' expected a tuple, got {type(value).__name__}")
  412. # Handle variable-length tuples: tuple[T, ...]
  413. if len(args) == 2 and args[1] is Ellipsis:
  414. for i, item in enumerate(value):
  415. try:
  416. type_validator(f"{name}[{i}]", item, args[0])
  417. except TypeError as e:
  418. raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e
  419. # Handle fixed-length tuples: tuple[T1, T2, ...]
  420. elif len(args) != len(value):
  421. raise TypeError(f"Field '{name}' expected a tuple of length {len(args)}, got {len(value)}")
  422. else:
  423. for i, (item, expected) in enumerate(zip(value, args)):
  424. try:
  425. type_validator(f"{name}[{i}]", item, expected)
  426. except TypeError as e:
  427. raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e
  428. def _validate_set(name: str, value: Any, args: tuple[Any, ...]) -> None:
  429. """Validate set[T] type."""
  430. if not isinstance(value, set):
  431. raise TypeError(f"Field '{name}' expected a set, got {type(value).__name__}")
  432. # Validate each item in the set
  433. item_type = args[0]
  434. for i, item in enumerate(value):
  435. try:
  436. type_validator(f"{name} item", item, item_type)
  437. except TypeError as e:
  438. raise TypeError(f"Invalid item in set '{name}'") from e
  439. def _validate_simple_type(name: str, value: Any, expected_type: type) -> None:
  440. """Validate simple type (int, str, etc.)."""
  441. if not isinstance(value, expected_type):
  442. raise TypeError(
  443. f"Field '{name}' expected {expected_type.__name__}, got {type(value).__name__} (value: {repr(value)})"
  444. )
  445. def _create_type_validator(field: Field) -> Validator_T:
  446. """Create a type validator function for a field."""
  447. # Hacky: we cannot use a lambda here because of reference issues
  448. def validator(value: Any) -> None:
  449. type_validator(field.name, value, field.type)
  450. return validator
  451. def _is_validator(validator: Any) -> bool:
  452. """Check if a function is a validator.
  453. A validator is a Callable that can be called with a single positional argument.
  454. The validator can have more arguments with default values.
  455. Basically, returns True if `validator(value)` is possible.
  456. """
  457. if not callable(validator):
  458. return False
  459. signature = inspect.signature(validator)
  460. parameters = list(signature.parameters.values())
  461. if len(parameters) == 0:
  462. return False
  463. if parameters[0].kind not in (
  464. inspect.Parameter.POSITIONAL_OR_KEYWORD,
  465. inspect.Parameter.POSITIONAL_ONLY,
  466. inspect.Parameter.VAR_POSITIONAL,
  467. ):
  468. return False
  469. for parameter in parameters[1:]:
  470. if parameter.default == inspect.Parameter.empty:
  471. return False
  472. return True
  473. def _is_required_or_notrequired(type_hint: Any) -> bool:
  474. """Helper to check if a type is Required/NotRequired."""
  475. return type_hint in (Required, NotRequired) or (get_origin(type_hint) in (Required, NotRequired))
  476. _BASIC_TYPE_VALIDATORS = {
  477. Union: _validate_union,
  478. Literal: _validate_literal,
  479. list: _validate_list,
  480. dict: _validate_dict,
  481. tuple: _validate_tuple,
  482. set: _validate_set,
  483. }
  484. if sys.version_info >= (3, 10):
  485. # TODO: make it first class citizen when bumping to Python 3.10+
  486. _BASIC_TYPE_VALIDATORS[types.UnionType] = _validate_union # x | y syntax, available only Python 3.10+
  487. __all__ = [
  488. "strict",
  489. "validate_typed_dict",
  490. "validated_field",
  491. "Validator_T",
  492. "StrictDataclassClassValidationError",
  493. "StrictDataclassDefinitionError",
  494. "StrictDataclassFieldValidationError",
  495. ]