einsumfunc.pyi 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. from collections.abc import Sequence
  2. from typing import TypeVar, Any, overload, Union, Literal
  3. from numpy import (
  4. ndarray,
  5. dtype,
  6. bool_,
  7. number,
  8. _OrderKACF,
  9. )
  10. from numpy._typing import (
  11. _ArrayLikeBool_co,
  12. _ArrayLikeUInt_co,
  13. _ArrayLikeInt_co,
  14. _ArrayLikeFloat_co,
  15. _ArrayLikeComplex_co,
  16. _ArrayLikeObject_co,
  17. _DTypeLikeBool,
  18. _DTypeLikeUInt,
  19. _DTypeLikeInt,
  20. _DTypeLikeFloat,
  21. _DTypeLikeComplex,
  22. _DTypeLikeComplex_co,
  23. _DTypeLikeObject,
  24. )
  25. _ArrayType = TypeVar(
  26. "_ArrayType",
  27. bound=ndarray[Any, dtype[Union[bool_, number[Any]]]],
  28. )
  29. _OptimizeKind = None | bool | Literal["greedy", "optimal"] | Sequence[Any]
  30. _CastingSafe = Literal["no", "equiv", "safe", "same_kind"]
  31. _CastingUnsafe = Literal["unsafe"]
  32. __all__: list[str]
  33. # TODO: Properly handle the `casting`-based combinatorics
  34. # TODO: We need to evaluate the content of `subscripts` in order
  35. # to identify whether an array or a scalar is returned. At a cursory
  36. # glance this seems like something that can quite easily be done with
  37. # a mypy plugin.
  38. # Something like `is_scalar = not subscripts.partition("->")[-1]` (valid only for the explicit `->` string form).
  39. @overload
  40. def einsum(
  41. subscripts: str | _ArrayLikeInt_co,
  42. /,
  43. *operands: _ArrayLikeBool_co,
  44. out: None = ...,
  45. dtype: None | _DTypeLikeBool = ...,
  46. order: _OrderKACF = ...,
  47. casting: _CastingSafe = ...,
  48. optimize: _OptimizeKind = ...,
  49. ) -> Any: ...
  50. @overload
  51. def einsum(
  52. subscripts: str | _ArrayLikeInt_co,
  53. /,
  54. *operands: _ArrayLikeUInt_co,
  55. out: None = ...,
  56. dtype: None | _DTypeLikeUInt = ...,
  57. order: _OrderKACF = ...,
  58. casting: _CastingSafe = ...,
  59. optimize: _OptimizeKind = ...,
  60. ) -> Any: ...
  61. @overload
  62. def einsum(
  63. subscripts: str | _ArrayLikeInt_co,
  64. /,
  65. *operands: _ArrayLikeInt_co,
  66. out: None = ...,
  67. dtype: None | _DTypeLikeInt = ...,
  68. order: _OrderKACF = ...,
  69. casting: _CastingSafe = ...,
  70. optimize: _OptimizeKind = ...,
  71. ) -> Any: ...
  72. @overload
  73. def einsum(
  74. subscripts: str | _ArrayLikeInt_co,
  75. /,
  76. *operands: _ArrayLikeFloat_co,
  77. out: None = ...,
  78. dtype: None | _DTypeLikeFloat = ...,
  79. order: _OrderKACF = ...,
  80. casting: _CastingSafe = ...,
  81. optimize: _OptimizeKind = ...,
  82. ) -> Any: ...
  83. @overload
  84. def einsum(
  85. subscripts: str | _ArrayLikeInt_co,
  86. /,
  87. *operands: _ArrayLikeComplex_co,
  88. out: None = ...,
  89. dtype: None | _DTypeLikeComplex = ...,
  90. order: _OrderKACF = ...,
  91. casting: _CastingSafe = ...,
  92. optimize: _OptimizeKind = ...,
  93. ) -> Any: ...
  94. @overload
  95. def einsum(
  96. subscripts: str | _ArrayLikeInt_co,
  97. /,
  98. *operands: Any,
  99. casting: _CastingUnsafe,
  100. dtype: None | _DTypeLikeComplex_co = ...,
  101. out: None = ...,
  102. order: _OrderKACF = ...,
  103. optimize: _OptimizeKind = ...,
  104. ) -> Any: ...
  105. @overload
  106. def einsum(
  107. subscripts: str | _ArrayLikeInt_co,
  108. /,
  109. *operands: _ArrayLikeComplex_co,
  110. out: _ArrayType,
  111. dtype: None | _DTypeLikeComplex_co = ...,
  112. order: _OrderKACF = ...,
  113. casting: _CastingSafe = ...,
  114. optimize: _OptimizeKind = ...,
  115. ) -> _ArrayType: ...
  116. @overload
  117. def einsum(
  118. subscripts: str | _ArrayLikeInt_co,
  119. /,
  120. *operands: Any,
  121. out: _ArrayType,
  122. casting: _CastingUnsafe,
  123. dtype: None | _DTypeLikeComplex_co = ...,
  124. order: _OrderKACF = ...,
  125. optimize: _OptimizeKind = ...,
  126. ) -> _ArrayType: ...
  127. @overload
  128. def einsum(
  129. subscripts: str | _ArrayLikeInt_co,
  130. /,
  131. *operands: _ArrayLikeObject_co,
  132. out: None = ...,
  133. dtype: None | _DTypeLikeObject = ...,
  134. order: _OrderKACF = ...,
  135. casting: _CastingSafe = ...,
  136. optimize: _OptimizeKind = ...,
  137. ) -> Any: ...
  138. @overload
  139. def einsum(
  140. subscripts: str | _ArrayLikeInt_co,
  141. /,
  142. *operands: Any,
  143. casting: _CastingUnsafe,
  144. dtype: None | _DTypeLikeObject = ...,
  145. out: None = ...,
  146. order: _OrderKACF = ...,
  147. optimize: _OptimizeKind = ...,
  148. ) -> Any: ...
  149. @overload
  150. def einsum(
  151. subscripts: str | _ArrayLikeInt_co,
  152. /,
  153. *operands: _ArrayLikeObject_co,
  154. out: _ArrayType,
  155. dtype: None | _DTypeLikeObject = ...,
  156. order: _OrderKACF = ...,
  157. casting: _CastingSafe = ...,
  158. optimize: _OptimizeKind = ...,
  159. ) -> _ArrayType: ...
  160. @overload
  161. def einsum(
  162. subscripts: str | _ArrayLikeInt_co,
  163. /,
  164. *operands: Any,
  165. out: _ArrayType,
  166. casting: _CastingUnsafe,
  167. dtype: None | _DTypeLikeObject = ...,
  168. order: _OrderKACF = ...,
  169. optimize: _OptimizeKind = ...,
  170. ) -> _ArrayType: ...
  171. # NOTE: `einsum_call` is a hidden kwarg unavailable for public use.
  172. # It is therefore excluded from the signatures below.
  173. # NOTE: In practice the list consists of a `str` (first element)
  174. # and a variable number of integer tuples.
  175. def einsum_path(
  176. subscripts: str | _ArrayLikeInt_co,
  177. /,
  178. *operands: _ArrayLikeComplex_co | _DTypeLikeObject,
  179. optimize: _OptimizeKind = ...,
  180. ) -> tuple[list[Any], str]: ...