dataclasses.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. import inspect
  2. from dataclasses import _MISSING_TYPE, MISSING, Field, field, fields
  3. from functools import wraps
  4. from typing import (
  5. Any,
  6. Callable,
  7. Dict,
  8. ForwardRef,
  9. List,
  10. Literal,
  11. Optional,
  12. Tuple,
  13. Type,
  14. TypeVar,
  15. Union,
  16. get_args,
  17. get_origin,
  18. overload,
  19. )
  20. from .errors import (
  21. StrictDataclassClassValidationError,
  22. StrictDataclassDefinitionError,
  23. StrictDataclassFieldValidationError,
  24. )
# Signature of a field validator: a callable invoked with the candidate value.
# It returns nothing; the strict `__setattr__` treats a raised ValueError or
# TypeError as a validation failure.
Validator_T = Callable[[Any], None]

# Type variable standing for the decorated class, so that `strict` is typed as
# returning the same class it receives.
T = TypeVar("T")
  27. # The overload decorator helps type checkers understand the different return types
  28. @overload
  29. def strict(cls: Type[T]) -> Type[T]: ...
  30. @overload
  31. def strict(*, accept_kwargs: bool = False) -> Callable[[Type[T]], Type[T]]: ...
  32. def strict(
  33. cls: Optional[Type[T]] = None, *, accept_kwargs: bool = False
  34. ) -> Union[Type[T], Callable[[Type[T]], Type[T]]]:
  35. """
  36. Decorator to add strict validation to a dataclass.
  37. This decorator must be used on top of `@dataclass` to ensure IDEs and static typing tools
  38. recognize the class as a dataclass.
  39. Can be used with or without arguments:
  40. - `@strict`
  41. - `@strict(accept_kwargs=True)`
  42. Args:
  43. cls:
  44. The class to convert to a strict dataclass.
  45. accept_kwargs (`bool`, *optional*):
  46. If True, allows arbitrary keyword arguments in `__init__`. Defaults to False.
  47. Returns:
  48. The enhanced dataclass with strict validation on field assignment.
  49. Example:
  50. ```py
  51. >>> from dataclasses import dataclass
  52. >>> from huggingface_hub.dataclasses import as_validated_field, strict, validated_field
  53. >>> @as_validated_field
  54. >>> def positive_int(value: int):
  55. ... if not value >= 0:
  56. ... raise ValueError(f"Value must be positive, got {value}")
  57. >>> @strict(accept_kwargs=True)
  58. ... @dataclass
  59. ... class User:
  60. ... name: str
  61. ... age: int = positive_int(default=10)
  62. # Initialize
  63. >>> User(name="John")
  64. User(name='John', age=10)
  65. # Extra kwargs are accepted
  66. >>> User(name="John", age=30, lastname="Doe")
  67. User(name='John', age=30, *lastname='Doe')
  68. # Invalid type => raises
  69. >>> User(name="John", age="30")
  70. huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
  71. TypeError: Field 'age' expected int, got str (value: '30')
  72. # Invalid value => raises
  73. >>> User(name="John", age=-1)
  74. huggingface_hub.errors.StrictDataclassFieldValidationError: Validation error for field 'age':
  75. ValueError: Value must be positive, got -1
  76. ```
  77. """
  78. def wrap(cls: Type[T]) -> Type[T]:
  79. if not hasattr(cls, "__dataclass_fields__"):
  80. raise StrictDataclassDefinitionError(
  81. f"Class '{cls.__name__}' must be a dataclass before applying @strict."
  82. )
  83. # List and store validators
  84. field_validators: Dict[str, List[Validator_T]] = {}
  85. for f in fields(cls): # type: ignore [arg-type]
  86. validators = []
  87. validators.append(_create_type_validator(f))
  88. custom_validator = f.metadata.get("validator")
  89. if custom_validator is not None:
  90. if not isinstance(custom_validator, list):
  91. custom_validator = [custom_validator]
  92. for validator in custom_validator:
  93. if not _is_validator(validator):
  94. raise StrictDataclassDefinitionError(
  95. f"Invalid validator for field '{f.name}': {validator}. Must be a callable taking a single argument."
  96. )
  97. validators.extend(custom_validator)
  98. field_validators[f.name] = validators
  99. cls.__validators__ = field_validators # type: ignore
  100. # Override __setattr__ to validate fields on assignment
  101. original_setattr = cls.__setattr__
  102. def __strict_setattr__(self: Any, name: str, value: Any) -> None:
  103. """Custom __setattr__ method for strict dataclasses."""
  104. # Run all validators
  105. for validator in self.__validators__.get(name, []):
  106. try:
  107. validator(value)
  108. except (ValueError, TypeError) as e:
  109. raise StrictDataclassFieldValidationError(field=name, cause=e) from e
  110. # If validation passed, set the attribute
  111. original_setattr(self, name, value)
  112. cls.__setattr__ = __strict_setattr__ # type: ignore[method-assign]
  113. if accept_kwargs:
  114. # (optional) Override __init__ to accept arbitrary keyword arguments
  115. original_init = cls.__init__
  116. @wraps(original_init)
  117. def __init__(self, **kwargs: Any) -> None:
  118. # Extract only the fields that are part of the dataclass
  119. dataclass_fields = {f.name for f in fields(cls)} # type: ignore [arg-type]
  120. standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields}
  121. # Call the original __init__ with standard fields
  122. original_init(self, **standard_kwargs)
  123. # Add any additional kwargs as attributes
  124. for name, value in kwargs.items():
  125. if name not in dataclass_fields:
  126. self.__setattr__(name, value)
  127. cls.__init__ = __init__ # type: ignore[method-assign]
  128. # (optional) Override __repr__ to include additional kwargs
  129. original_repr = cls.__repr__
  130. @wraps(original_repr)
  131. def __repr__(self) -> str:
  132. # Call the original __repr__ to get the standard fields
  133. standard_repr = original_repr(self)
  134. # Get additional kwargs
  135. additional_kwargs = [
  136. # add a '*' in front of additional kwargs to let the user know they are not part of the dataclass
  137. f"*{k}={v!r}"
  138. for k, v in self.__dict__.items()
  139. if k not in cls.__dataclass_fields__ # type: ignore [attr-defined]
  140. ]
  141. additional_repr = ", ".join(additional_kwargs)
  142. # Combine both representations
  143. return f"{standard_repr[:-1]}, {additional_repr})" if additional_kwargs else standard_repr
  144. cls.__repr__ = __repr__ # type: ignore [method-assign]
  145. # List all public methods starting with `validate_` => class validators.
  146. class_validators = []
  147. for name in dir(cls):
  148. if not name.startswith("validate_"):
  149. continue
  150. method = getattr(cls, name)
  151. if not callable(method):
  152. continue
  153. if len(inspect.signature(method).parameters) != 1:
  154. raise StrictDataclassDefinitionError(
  155. f"Class '{cls.__name__}' has a class validator '{name}' that takes more than one argument."
  156. " Class validators must take only 'self' as an argument. Methods starting with 'validate_'"
  157. " are considered to be class validators."
  158. )
  159. class_validators.append(method)
  160. cls.__class_validators__ = class_validators # type: ignore [attr-defined]
  161. # Add `validate` method to the class, but first check if it already exists
  162. def validate(self: T) -> None:
  163. """Run class validators on the instance."""
  164. for validator in cls.__class_validators__: # type: ignore [attr-defined]
  165. try:
  166. validator(self)
  167. except (ValueError, TypeError) as e:
  168. raise StrictDataclassClassValidationError(validator=validator.__name__, cause=e) from e
  169. # Hack to be able to raise if `.validate()` already exists except if it was created by this decorator on a parent class
  170. # (in which case we just override it)
  171. validate.__is_defined_by_strict_decorator__ = True # type: ignore [attr-defined]
  172. if hasattr(cls, "validate"):
  173. if not getattr(cls.validate, "__is_defined_by_strict_decorator__", False): # type: ignore [attr-defined]
  174. raise StrictDataclassDefinitionError(
  175. f"Class '{cls.__name__}' already implements a method called 'validate'."
  176. " This method name is reserved when using the @strict decorator on a dataclass."
  177. " If you want to keep your own method, please rename it."
  178. )
  179. cls.validate = validate # type: ignore
  180. # Run class validators after initialization
  181. initial_init = cls.__init__
  182. @wraps(initial_init)
  183. def init_with_validate(self, *args, **kwargs) -> None:
  184. """Run class validators after initialization."""
  185. initial_init(self, *args, **kwargs) # type: ignore [call-arg]
  186. cls.validate(self) # type: ignore [attr-defined]
  187. setattr(cls, "__init__", init_with_validate)
  188. return cls
  189. # Return wrapped class or the decorator itself
  190. return wrap(cls) if cls is not None else wrap
  191. def validated_field(
  192. validator: Union[List[Validator_T], Validator_T],
  193. default: Union[Any, _MISSING_TYPE] = MISSING,
  194. default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING,
  195. init: bool = True,
  196. repr: bool = True,
  197. hash: Optional[bool] = None,
  198. compare: bool = True,
  199. metadata: Optional[Dict] = None,
  200. **kwargs: Any,
  201. ) -> Any:
  202. """
  203. Create a dataclass field with a custom validator.
  204. Useful to apply several checks to a field. If only applying one rule, check out the [`as_validated_field`] decorator.
  205. Args:
  206. validator (`Callable` or `List[Callable]`):
  207. A method that takes a value as input and raises ValueError/TypeError if the value is invalid.
  208. Can be a list of validators to apply multiple checks.
  209. **kwargs:
  210. Additional arguments to pass to `dataclasses.field()`.
  211. Returns:
  212. A field with the validator attached in metadata
  213. """
  214. if not isinstance(validator, list):
  215. validator = [validator]
  216. if metadata is None:
  217. metadata = {}
  218. metadata["validator"] = validator
  219. return field( # type: ignore
  220. default=default, # type: ignore [arg-type]
  221. default_factory=default_factory, # type: ignore [arg-type]
  222. init=init,
  223. repr=repr,
  224. hash=hash,
  225. compare=compare,
  226. metadata=metadata,
  227. **kwargs,
  228. )
  229. def as_validated_field(validator: Validator_T):
  230. """
  231. Decorates a validator function as a [`validated_field`] (i.e. a dataclass field with a custom validator).
  232. Args:
  233. validator (`Callable`):
  234. A method that takes a value as input and raises ValueError/TypeError if the value is invalid.
  235. """
  236. def _inner(
  237. default: Union[Any, _MISSING_TYPE] = MISSING,
  238. default_factory: Union[Callable[[], Any], _MISSING_TYPE] = MISSING,
  239. init: bool = True,
  240. repr: bool = True,
  241. hash: Optional[bool] = None,
  242. compare: bool = True,
  243. metadata: Optional[Dict] = None,
  244. **kwargs: Any,
  245. ):
  246. return validated_field(
  247. validator,
  248. default=default,
  249. default_factory=default_factory,
  250. init=init,
  251. repr=repr,
  252. hash=hash,
  253. compare=compare,
  254. metadata=metadata,
  255. **kwargs,
  256. )
  257. return _inner
  258. def type_validator(name: str, value: Any, expected_type: Any) -> None:
  259. """Validate that 'value' matches 'expected_type'."""
  260. origin = get_origin(expected_type)
  261. args = get_args(expected_type)
  262. if expected_type is Any:
  263. return
  264. elif validator := _BASIC_TYPE_VALIDATORS.get(origin):
  265. validator(name, value, args)
  266. elif isinstance(expected_type, type): # simple types
  267. _validate_simple_type(name, value, expected_type)
  268. elif isinstance(expected_type, ForwardRef) or isinstance(expected_type, str):
  269. return
  270. else:
  271. raise TypeError(f"Unsupported type for field '{name}': {expected_type}")
  272. def _validate_union(name: str, value: Any, args: Tuple[Any, ...]) -> None:
  273. """Validate that value matches one of the types in a Union."""
  274. errors = []
  275. for t in args:
  276. try:
  277. type_validator(name, value, t)
  278. return # Valid if any type matches
  279. except TypeError as e:
  280. errors.append(str(e))
  281. raise TypeError(
  282. f"Field '{name}' with value {repr(value)} doesn't match any type in {args}. Errors: {'; '.join(errors)}"
  283. )
  284. def _validate_literal(name: str, value: Any, args: Tuple[Any, ...]) -> None:
  285. """Validate Literal type."""
  286. if value not in args:
  287. raise TypeError(f"Field '{name}' expected one of {args}, got {value}")
  288. def _validate_list(name: str, value: Any, args: Tuple[Any, ...]) -> None:
  289. """Validate List[T] type."""
  290. if not isinstance(value, list):
  291. raise TypeError(f"Field '{name}' expected a list, got {type(value).__name__}")
  292. # Validate each item in the list
  293. item_type = args[0]
  294. for i, item in enumerate(value):
  295. try:
  296. type_validator(f"{name}[{i}]", item, item_type)
  297. except TypeError as e:
  298. raise TypeError(f"Invalid item at index {i} in list '{name}'") from e
  299. def _validate_dict(name: str, value: Any, args: Tuple[Any, ...]) -> None:
  300. """Validate Dict[K, V] type."""
  301. if not isinstance(value, dict):
  302. raise TypeError(f"Field '{name}' expected a dict, got {type(value).__name__}")
  303. # Validate keys and values
  304. key_type, value_type = args
  305. for k, v in value.items():
  306. try:
  307. type_validator(f"{name}.key", k, key_type)
  308. type_validator(f"{name}[{k!r}]", v, value_type)
  309. except TypeError as e:
  310. raise TypeError(f"Invalid key or value in dict '{name}'") from e
  311. def _validate_tuple(name: str, value: Any, args: Tuple[Any, ...]) -> None:
  312. """Validate Tuple type."""
  313. if not isinstance(value, tuple):
  314. raise TypeError(f"Field '{name}' expected a tuple, got {type(value).__name__}")
  315. # Handle variable-length tuples: Tuple[T, ...]
  316. if len(args) == 2 and args[1] is Ellipsis:
  317. for i, item in enumerate(value):
  318. try:
  319. type_validator(f"{name}[{i}]", item, args[0])
  320. except TypeError as e:
  321. raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e
  322. # Handle fixed-length tuples: Tuple[T1, T2, ...]
  323. elif len(args) != len(value):
  324. raise TypeError(f"Field '{name}' expected a tuple of length {len(args)}, got {len(value)}")
  325. else:
  326. for i, (item, expected) in enumerate(zip(value, args)):
  327. try:
  328. type_validator(f"{name}[{i}]", item, expected)
  329. except TypeError as e:
  330. raise TypeError(f"Invalid item at index {i} in tuple '{name}'") from e
  331. def _validate_set(name: str, value: Any, args: Tuple[Any, ...]) -> None:
  332. """Validate Set[T] type."""
  333. if not isinstance(value, set):
  334. raise TypeError(f"Field '{name}' expected a set, got {type(value).__name__}")
  335. # Validate each item in the set
  336. item_type = args[0]
  337. for i, item in enumerate(value):
  338. try:
  339. type_validator(f"{name} item", item, item_type)
  340. except TypeError as e:
  341. raise TypeError(f"Invalid item in set '{name}'") from e
  342. def _validate_simple_type(name: str, value: Any, expected_type: type) -> None:
  343. """Validate simple type (int, str, etc.)."""
  344. if not isinstance(value, expected_type):
  345. raise TypeError(
  346. f"Field '{name}' expected {expected_type.__name__}, got {type(value).__name__} (value: {repr(value)})"
  347. )
  348. def _create_type_validator(field: Field) -> Validator_T:
  349. """Create a type validator function for a field."""
  350. # Hacky: we cannot use a lambda here because of reference issues
  351. def validator(value: Any) -> None:
  352. type_validator(field.name, value, field.type)
  353. return validator
  354. def _is_validator(validator: Any) -> bool:
  355. """Check if a function is a validator.
  356. A validator is a Callable that can be called with a single positional argument.
  357. The validator can have more arguments with default values.
  358. Basically, returns True if `validator(value)` is possible.
  359. """
  360. if not callable(validator):
  361. return False
  362. signature = inspect.signature(validator)
  363. parameters = list(signature.parameters.values())
  364. if len(parameters) == 0:
  365. return False
  366. if parameters[0].kind not in (
  367. inspect.Parameter.POSITIONAL_OR_KEYWORD,
  368. inspect.Parameter.POSITIONAL_ONLY,
  369. inspect.Parameter.VAR_POSITIONAL,
  370. ):
  371. return False
  372. for parameter in parameters[1:]:
  373. if parameter.default == inspect.Parameter.empty:
  374. return False
  375. return True
# Dispatch table used by `type_validator`: maps the origin of an annotation (as returned by
# `typing.get_origin`) to the helper that validates it. Each helper is called with
# `(name, value, args)` where `args` comes from `typing.get_args`.
_BASIC_TYPE_VALIDATORS = {
    Union: _validate_union,
    Literal: _validate_literal,
    list: _validate_list,
    dict: _validate_dict,
    tuple: _validate_tuple,
    set: _validate_set,
}
  384. __all__ = [
  385. "strict",
  386. "validated_field",
  387. "Validator_T",
  388. "StrictDataclassClassValidationError",
  389. "StrictDataclassDefinitionError",
  390. "StrictDataclassFieldValidationError",
  391. ]