test_runtime.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. """Test the runtime usage of `numpy.typing`."""
  2. from typing import (
  3. Any,
  4. NamedTuple,
  5. Union, # pyright: ignore[reportDeprecated]
  6. get_args,
  7. get_origin,
  8. get_type_hints,
  9. )
  10. import pytest
  11. import numpy as np
  12. import numpy._typing as _npt
  13. import numpy.typing as npt
  14. class TypeTup(NamedTuple):
  15. typ: type
  16. args: tuple[type, ...]
  17. origin: type | None
  18. def _flatten_type_alias(t: Any) -> Any:
  19. # "flattens" a TypeAliasType to its underlying type alias
  20. return getattr(t, "__value__", t)
  21. NDArrayTup = TypeTup(npt.NDArray, npt.NDArray.__args__, np.ndarray)
  22. TYPES = {
  23. "ArrayLike": TypeTup(
  24. _flatten_type_alias(npt.ArrayLike),
  25. _flatten_type_alias(npt.ArrayLike).__args__,
  26. Union,
  27. ),
  28. "DTypeLike": TypeTup(
  29. _flatten_type_alias(npt.DTypeLike),
  30. _flatten_type_alias(npt.DTypeLike).__args__,
  31. Union,
  32. ),
  33. "NBitBase": TypeTup(npt.NBitBase, (), None), # type: ignore[deprecated] # pyright: ignore[reportDeprecated]
  34. "NDArray": NDArrayTup,
  35. }
  36. @pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
  37. def test_get_args(name: type, tup: TypeTup) -> None:
  38. """Test `typing.get_args`."""
  39. typ, ref = tup.typ, tup.args
  40. out = get_args(typ)
  41. assert out == ref
  42. @pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
  43. def test_get_origin(name: type, tup: TypeTup) -> None:
  44. """Test `typing.get_origin`."""
  45. typ, ref = tup.typ, tup.origin
  46. out = get_origin(typ)
  47. assert out == ref
  48. @pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
  49. def test_get_type_hints(name: type, tup: TypeTup) -> None:
  50. """Test `typing.get_type_hints`."""
  51. typ = tup.typ
  52. def func(a: typ) -> None: pass
  53. out = get_type_hints(func)
  54. ref = {"a": typ, "return": type(None)}
  55. assert out == ref
  56. @pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
  57. def test_get_type_hints_str(name: type, tup: TypeTup) -> None:
  58. """Test `typing.get_type_hints` with string-representation of types."""
  59. typ_str, typ = f"npt.{name}", tup.typ
  60. def func(a: typ_str) -> None: pass
  61. out = get_type_hints(func)
  62. ref = {"a": getattr(npt, str(name)), "return": type(None)}
  63. assert out == ref
  64. def test_keys() -> None:
  65. """Test that ``TYPES.keys()`` and ``numpy.typing.__all__`` are synced."""
  66. keys = TYPES.keys()
  67. ref = set(npt.__all__)
  68. assert keys == ref
  69. PROTOCOLS: dict[str, tuple[type[Any], object]] = {
  70. "_SupportsArray": (_npt._SupportsArray, np.arange(10)),
  71. "_SupportsArrayFunc": (_npt._SupportsArrayFunc, np.arange(10)),
  72. "_NestedSequence": (_npt._NestedSequence, [1]),
  73. }
  74. @pytest.mark.parametrize("cls,obj", PROTOCOLS.values(), ids=PROTOCOLS.keys())
  75. class TestRuntimeProtocol:
  76. def test_isinstance(self, cls: type[Any], obj: object) -> None:
  77. assert isinstance(obj, cls)
  78. assert not isinstance(None, cls)
  79. def test_issubclass(self, cls: type[Any], obj: object) -> None:
  80. assert issubclass(type(obj), cls)
  81. assert not issubclass(type(None), cls)