_elementwise_functions.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765
  1. from __future__ import annotations
  2. from ._dtypes import (
  3. _boolean_dtypes,
  4. _floating_dtypes,
  5. _real_floating_dtypes,
  6. _complex_floating_dtypes,
  7. _integer_dtypes,
  8. _integer_or_boolean_dtypes,
  9. _real_numeric_dtypes,
  10. _numeric_dtypes,
  11. _result_type,
  12. )
  13. from ._array_object import Array
  14. import numpy as np
  15. def abs(x: Array, /) -> Array:
  16. """
  17. Array API compatible wrapper for :py:func:`np.abs <numpy.abs>`.
  18. See its docstring for more information.
  19. """
  20. if x.dtype not in _numeric_dtypes:
  21. raise TypeError("Only numeric dtypes are allowed in abs")
  22. return Array._new(np.abs(x._array))
  23. # Note: the function name is different here
  24. def acos(x: Array, /) -> Array:
  25. """
  26. Array API compatible wrapper for :py:func:`np.arccos <numpy.arccos>`.
  27. See its docstring for more information.
  28. """
  29. if x.dtype not in _floating_dtypes:
  30. raise TypeError("Only floating-point dtypes are allowed in acos")
  31. return Array._new(np.arccos(x._array))
  32. # Note: the function name is different here
  33. def acosh(x: Array, /) -> Array:
  34. """
  35. Array API compatible wrapper for :py:func:`np.arccosh <numpy.arccosh>`.
  36. See its docstring for more information.
  37. """
  38. if x.dtype not in _floating_dtypes:
  39. raise TypeError("Only floating-point dtypes are allowed in acosh")
  40. return Array._new(np.arccosh(x._array))
  41. def add(x1: Array, x2: Array, /) -> Array:
  42. """
  43. Array API compatible wrapper for :py:func:`np.add <numpy.add>`.
  44. See its docstring for more information.
  45. """
  46. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
  47. raise TypeError("Only numeric dtypes are allowed in add")
  48. # Call result type here just to raise on disallowed type combinations
  49. _result_type(x1.dtype, x2.dtype)
  50. x1, x2 = Array._normalize_two_args(x1, x2)
  51. return Array._new(np.add(x1._array, x2._array))
  52. # Note: the function name is different here
  53. def asin(x: Array, /) -> Array:
  54. """
  55. Array API compatible wrapper for :py:func:`np.arcsin <numpy.arcsin>`.
  56. See its docstring for more information.
  57. """
  58. if x.dtype not in _floating_dtypes:
  59. raise TypeError("Only floating-point dtypes are allowed in asin")
  60. return Array._new(np.arcsin(x._array))
  61. # Note: the function name is different here
  62. def asinh(x: Array, /) -> Array:
  63. """
  64. Array API compatible wrapper for :py:func:`np.arcsinh <numpy.arcsinh>`.
  65. See its docstring for more information.
  66. """
  67. if x.dtype not in _floating_dtypes:
  68. raise TypeError("Only floating-point dtypes are allowed in asinh")
  69. return Array._new(np.arcsinh(x._array))
  70. # Note: the function name is different here
  71. def atan(x: Array, /) -> Array:
  72. """
  73. Array API compatible wrapper for :py:func:`np.arctan <numpy.arctan>`.
  74. See its docstring for more information.
  75. """
  76. if x.dtype not in _floating_dtypes:
  77. raise TypeError("Only floating-point dtypes are allowed in atan")
  78. return Array._new(np.arctan(x._array))
  79. # Note: the function name is different here
  80. def atan2(x1: Array, x2: Array, /) -> Array:
  81. """
  82. Array API compatible wrapper for :py:func:`np.arctan2 <numpy.arctan2>`.
  83. See its docstring for more information.
  84. """
  85. if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
  86. raise TypeError("Only real floating-point dtypes are allowed in atan2")
  87. # Call result type here just to raise on disallowed type combinations
  88. _result_type(x1.dtype, x2.dtype)
  89. x1, x2 = Array._normalize_two_args(x1, x2)
  90. return Array._new(np.arctan2(x1._array, x2._array))
  91. # Note: the function name is different here
  92. def atanh(x: Array, /) -> Array:
  93. """
  94. Array API compatible wrapper for :py:func:`np.arctanh <numpy.arctanh>`.
  95. See its docstring for more information.
  96. """
  97. if x.dtype not in _floating_dtypes:
  98. raise TypeError("Only floating-point dtypes are allowed in atanh")
  99. return Array._new(np.arctanh(x._array))
  100. def bitwise_and(x1: Array, x2: Array, /) -> Array:
  101. """
  102. Array API compatible wrapper for :py:func:`np.bitwise_and <numpy.bitwise_and>`.
  103. See its docstring for more information.
  104. """
  105. if (
  106. x1.dtype not in _integer_or_boolean_dtypes
  107. or x2.dtype not in _integer_or_boolean_dtypes
  108. ):
  109. raise TypeError("Only integer or boolean dtypes are allowed in bitwise_and")
  110. # Call result type here just to raise on disallowed type combinations
  111. _result_type(x1.dtype, x2.dtype)
  112. x1, x2 = Array._normalize_two_args(x1, x2)
  113. return Array._new(np.bitwise_and(x1._array, x2._array))
  114. # Note: the function name is different here
  115. def bitwise_left_shift(x1: Array, x2: Array, /) -> Array:
  116. """
  117. Array API compatible wrapper for :py:func:`np.left_shift <numpy.left_shift>`.
  118. See its docstring for more information.
  119. """
  120. if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
  121. raise TypeError("Only integer dtypes are allowed in bitwise_left_shift")
  122. # Call result type here just to raise on disallowed type combinations
  123. _result_type(x1.dtype, x2.dtype)
  124. x1, x2 = Array._normalize_two_args(x1, x2)
  125. # Note: bitwise_left_shift is only defined for x2 nonnegative.
  126. if np.any(x2._array < 0):
  127. raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0")
  128. return Array._new(np.left_shift(x1._array, x2._array))
  129. # Note: the function name is different here
  130. def bitwise_invert(x: Array, /) -> Array:
  131. """
  132. Array API compatible wrapper for :py:func:`np.invert <numpy.invert>`.
  133. See its docstring for more information.
  134. """
  135. if x.dtype not in _integer_or_boolean_dtypes:
  136. raise TypeError("Only integer or boolean dtypes are allowed in bitwise_invert")
  137. return Array._new(np.invert(x._array))
  138. def bitwise_or(x1: Array, x2: Array, /) -> Array:
  139. """
  140. Array API compatible wrapper for :py:func:`np.bitwise_or <numpy.bitwise_or>`.
  141. See its docstring for more information.
  142. """
  143. if (
  144. x1.dtype not in _integer_or_boolean_dtypes
  145. or x2.dtype not in _integer_or_boolean_dtypes
  146. ):
  147. raise TypeError("Only integer or boolean dtypes are allowed in bitwise_or")
  148. # Call result type here just to raise on disallowed type combinations
  149. _result_type(x1.dtype, x2.dtype)
  150. x1, x2 = Array._normalize_two_args(x1, x2)
  151. return Array._new(np.bitwise_or(x1._array, x2._array))
  152. # Note: the function name is different here
  153. def bitwise_right_shift(x1: Array, x2: Array, /) -> Array:
  154. """
  155. Array API compatible wrapper for :py:func:`np.right_shift <numpy.right_shift>`.
  156. See its docstring for more information.
  157. """
  158. if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
  159. raise TypeError("Only integer dtypes are allowed in bitwise_right_shift")
  160. # Call result type here just to raise on disallowed type combinations
  161. _result_type(x1.dtype, x2.dtype)
  162. x1, x2 = Array._normalize_two_args(x1, x2)
  163. # Note: bitwise_right_shift is only defined for x2 nonnegative.
  164. if np.any(x2._array < 0):
  165. raise ValueError("bitwise_right_shift(x1, x2) is only defined for x2 >= 0")
  166. return Array._new(np.right_shift(x1._array, x2._array))
  167. def bitwise_xor(x1: Array, x2: Array, /) -> Array:
  168. """
  169. Array API compatible wrapper for :py:func:`np.bitwise_xor <numpy.bitwise_xor>`.
  170. See its docstring for more information.
  171. """
  172. if (
  173. x1.dtype not in _integer_or_boolean_dtypes
  174. or x2.dtype not in _integer_or_boolean_dtypes
  175. ):
  176. raise TypeError("Only integer or boolean dtypes are allowed in bitwise_xor")
  177. # Call result type here just to raise on disallowed type combinations
  178. _result_type(x1.dtype, x2.dtype)
  179. x1, x2 = Array._normalize_two_args(x1, x2)
  180. return Array._new(np.bitwise_xor(x1._array, x2._array))
  181. def ceil(x: Array, /) -> Array:
  182. """
  183. Array API compatible wrapper for :py:func:`np.ceil <numpy.ceil>`.
  184. See its docstring for more information.
  185. """
  186. if x.dtype not in _real_numeric_dtypes:
  187. raise TypeError("Only real numeric dtypes are allowed in ceil")
  188. if x.dtype in _integer_dtypes:
  189. # Note: The return dtype of ceil is the same as the input
  190. return x
  191. return Array._new(np.ceil(x._array))
  192. def conj(x: Array, /) -> Array:
  193. """
  194. Array API compatible wrapper for :py:func:`np.conj <numpy.conj>`.
  195. See its docstring for more information.
  196. """
  197. if x.dtype not in _complex_floating_dtypes:
  198. raise TypeError("Only complex floating-point dtypes are allowed in conj")
  199. return Array._new(np.conj(x))
  200. def cos(x: Array, /) -> Array:
  201. """
  202. Array API compatible wrapper for :py:func:`np.cos <numpy.cos>`.
  203. See its docstring for more information.
  204. """
  205. if x.dtype not in _floating_dtypes:
  206. raise TypeError("Only floating-point dtypes are allowed in cos")
  207. return Array._new(np.cos(x._array))
  208. def cosh(x: Array, /) -> Array:
  209. """
  210. Array API compatible wrapper for :py:func:`np.cosh <numpy.cosh>`.
  211. See its docstring for more information.
  212. """
  213. if x.dtype not in _floating_dtypes:
  214. raise TypeError("Only floating-point dtypes are allowed in cosh")
  215. return Array._new(np.cosh(x._array))
  216. def divide(x1: Array, x2: Array, /) -> Array:
  217. """
  218. Array API compatible wrapper for :py:func:`np.divide <numpy.divide>`.
  219. See its docstring for more information.
  220. """
  221. if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
  222. raise TypeError("Only floating-point dtypes are allowed in divide")
  223. # Call result type here just to raise on disallowed type combinations
  224. _result_type(x1.dtype, x2.dtype)
  225. x1, x2 = Array._normalize_two_args(x1, x2)
  226. return Array._new(np.divide(x1._array, x2._array))
  227. def equal(x1: Array, x2: Array, /) -> Array:
  228. """
  229. Array API compatible wrapper for :py:func:`np.equal <numpy.equal>`.
  230. See its docstring for more information.
  231. """
  232. # Call result type here just to raise on disallowed type combinations
  233. _result_type(x1.dtype, x2.dtype)
  234. x1, x2 = Array._normalize_two_args(x1, x2)
  235. return Array._new(np.equal(x1._array, x2._array))
  236. def exp(x: Array, /) -> Array:
  237. """
  238. Array API compatible wrapper for :py:func:`np.exp <numpy.exp>`.
  239. See its docstring for more information.
  240. """
  241. if x.dtype not in _floating_dtypes:
  242. raise TypeError("Only floating-point dtypes are allowed in exp")
  243. return Array._new(np.exp(x._array))
  244. def expm1(x: Array, /) -> Array:
  245. """
  246. Array API compatible wrapper for :py:func:`np.expm1 <numpy.expm1>`.
  247. See its docstring for more information.
  248. """
  249. if x.dtype not in _floating_dtypes:
  250. raise TypeError("Only floating-point dtypes are allowed in expm1")
  251. return Array._new(np.expm1(x._array))
  252. def floor(x: Array, /) -> Array:
  253. """
  254. Array API compatible wrapper for :py:func:`np.floor <numpy.floor>`.
  255. See its docstring for more information.
  256. """
  257. if x.dtype not in _real_numeric_dtypes:
  258. raise TypeError("Only real numeric dtypes are allowed in floor")
  259. if x.dtype in _integer_dtypes:
  260. # Note: The return dtype of floor is the same as the input
  261. return x
  262. return Array._new(np.floor(x._array))
  263. def floor_divide(x1: Array, x2: Array, /) -> Array:
  264. """
  265. Array API compatible wrapper for :py:func:`np.floor_divide <numpy.floor_divide>`.
  266. See its docstring for more information.
  267. """
  268. if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
  269. raise TypeError("Only real numeric dtypes are allowed in floor_divide")
  270. # Call result type here just to raise on disallowed type combinations
  271. _result_type(x1.dtype, x2.dtype)
  272. x1, x2 = Array._normalize_two_args(x1, x2)
  273. return Array._new(np.floor_divide(x1._array, x2._array))
  274. def greater(x1: Array, x2: Array, /) -> Array:
  275. """
  276. Array API compatible wrapper for :py:func:`np.greater <numpy.greater>`.
  277. See its docstring for more information.
  278. """
  279. if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
  280. raise TypeError("Only real numeric dtypes are allowed in greater")
  281. # Call result type here just to raise on disallowed type combinations
  282. _result_type(x1.dtype, x2.dtype)
  283. x1, x2 = Array._normalize_two_args(x1, x2)
  284. return Array._new(np.greater(x1._array, x2._array))
  285. def greater_equal(x1: Array, x2: Array, /) -> Array:
  286. """
  287. Array API compatible wrapper for :py:func:`np.greater_equal <numpy.greater_equal>`.
  288. See its docstring for more information.
  289. """
  290. if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
  291. raise TypeError("Only real numeric dtypes are allowed in greater_equal")
  292. # Call result type here just to raise on disallowed type combinations
  293. _result_type(x1.dtype, x2.dtype)
  294. x1, x2 = Array._normalize_two_args(x1, x2)
  295. return Array._new(np.greater_equal(x1._array, x2._array))
  296. def imag(x: Array, /) -> Array:
  297. """
  298. Array API compatible wrapper for :py:func:`np.imag <numpy.imag>`.
  299. See its docstring for more information.
  300. """
  301. if x.dtype not in _complex_floating_dtypes:
  302. raise TypeError("Only complex floating-point dtypes are allowed in imag")
  303. return Array._new(np.imag(x))
  304. def isfinite(x: Array, /) -> Array:
  305. """
  306. Array API compatible wrapper for :py:func:`np.isfinite <numpy.isfinite>`.
  307. See its docstring for more information.
  308. """
  309. if x.dtype not in _numeric_dtypes:
  310. raise TypeError("Only numeric dtypes are allowed in isfinite")
  311. return Array._new(np.isfinite(x._array))
  312. def isinf(x: Array, /) -> Array:
  313. """
  314. Array API compatible wrapper for :py:func:`np.isinf <numpy.isinf>`.
  315. See its docstring for more information.
  316. """
  317. if x.dtype not in _numeric_dtypes:
  318. raise TypeError("Only numeric dtypes are allowed in isinf")
  319. return Array._new(np.isinf(x._array))
  320. def isnan(x: Array, /) -> Array:
  321. """
  322. Array API compatible wrapper for :py:func:`np.isnan <numpy.isnan>`.
  323. See its docstring for more information.
  324. """
  325. if x.dtype not in _numeric_dtypes:
  326. raise TypeError("Only numeric dtypes are allowed in isnan")
  327. return Array._new(np.isnan(x._array))
  328. def less(x1: Array, x2: Array, /) -> Array:
  329. """
  330. Array API compatible wrapper for :py:func:`np.less <numpy.less>`.
  331. See its docstring for more information.
  332. """
  333. if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
  334. raise TypeError("Only real numeric dtypes are allowed in less")
  335. # Call result type here just to raise on disallowed type combinations
  336. _result_type(x1.dtype, x2.dtype)
  337. x1, x2 = Array._normalize_two_args(x1, x2)
  338. return Array._new(np.less(x1._array, x2._array))
  339. def less_equal(x1: Array, x2: Array, /) -> Array:
  340. """
  341. Array API compatible wrapper for :py:func:`np.less_equal <numpy.less_equal>`.
  342. See its docstring for more information.
  343. """
  344. if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
  345. raise TypeError("Only real numeric dtypes are allowed in less_equal")
  346. # Call result type here just to raise on disallowed type combinations
  347. _result_type(x1.dtype, x2.dtype)
  348. x1, x2 = Array._normalize_two_args(x1, x2)
  349. return Array._new(np.less_equal(x1._array, x2._array))
  350. def log(x: Array, /) -> Array:
  351. """
  352. Array API compatible wrapper for :py:func:`np.log <numpy.log>`.
  353. See its docstring for more information.
  354. """
  355. if x.dtype not in _floating_dtypes:
  356. raise TypeError("Only floating-point dtypes are allowed in log")
  357. return Array._new(np.log(x._array))
  358. def log1p(x: Array, /) -> Array:
  359. """
  360. Array API compatible wrapper for :py:func:`np.log1p <numpy.log1p>`.
  361. See its docstring for more information.
  362. """
  363. if x.dtype not in _floating_dtypes:
  364. raise TypeError("Only floating-point dtypes are allowed in log1p")
  365. return Array._new(np.log1p(x._array))
  366. def log2(x: Array, /) -> Array:
  367. """
  368. Array API compatible wrapper for :py:func:`np.log2 <numpy.log2>`.
  369. See its docstring for more information.
  370. """
  371. if x.dtype not in _floating_dtypes:
  372. raise TypeError("Only floating-point dtypes are allowed in log2")
  373. return Array._new(np.log2(x._array))
  374. def log10(x: Array, /) -> Array:
  375. """
  376. Array API compatible wrapper for :py:func:`np.log10 <numpy.log10>`.
  377. See its docstring for more information.
  378. """
  379. if x.dtype not in _floating_dtypes:
  380. raise TypeError("Only floating-point dtypes are allowed in log10")
  381. return Array._new(np.log10(x._array))
  382. def logaddexp(x1: Array, x2: Array) -> Array:
  383. """
  384. Array API compatible wrapper for :py:func:`np.logaddexp <numpy.logaddexp>`.
  385. See its docstring for more information.
  386. """
  387. if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
  388. raise TypeError("Only real floating-point dtypes are allowed in logaddexp")
  389. # Call result type here just to raise on disallowed type combinations
  390. _result_type(x1.dtype, x2.dtype)
  391. x1, x2 = Array._normalize_two_args(x1, x2)
  392. return Array._new(np.logaddexp(x1._array, x2._array))
  393. def logical_and(x1: Array, x2: Array, /) -> Array:
  394. """
  395. Array API compatible wrapper for :py:func:`np.logical_and <numpy.logical_and>`.
  396. See its docstring for more information.
  397. """
  398. if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
  399. raise TypeError("Only boolean dtypes are allowed in logical_and")
  400. # Call result type here just to raise on disallowed type combinations
  401. _result_type(x1.dtype, x2.dtype)
  402. x1, x2 = Array._normalize_two_args(x1, x2)
  403. return Array._new(np.logical_and(x1._array, x2._array))
  404. def logical_not(x: Array, /) -> Array:
  405. """
  406. Array API compatible wrapper for :py:func:`np.logical_not <numpy.logical_not>`.
  407. See its docstring for more information.
  408. """
  409. if x.dtype not in _boolean_dtypes:
  410. raise TypeError("Only boolean dtypes are allowed in logical_not")
  411. return Array._new(np.logical_not(x._array))
  412. def logical_or(x1: Array, x2: Array, /) -> Array:
  413. """
  414. Array API compatible wrapper for :py:func:`np.logical_or <numpy.logical_or>`.
  415. See its docstring for more information.
  416. """
  417. if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
  418. raise TypeError("Only boolean dtypes are allowed in logical_or")
  419. # Call result type here just to raise on disallowed type combinations
  420. _result_type(x1.dtype, x2.dtype)
  421. x1, x2 = Array._normalize_two_args(x1, x2)
  422. return Array._new(np.logical_or(x1._array, x2._array))
  423. def logical_xor(x1: Array, x2: Array, /) -> Array:
  424. """
  425. Array API compatible wrapper for :py:func:`np.logical_xor <numpy.logical_xor>`.
  426. See its docstring for more information.
  427. """
  428. if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
  429. raise TypeError("Only boolean dtypes are allowed in logical_xor")
  430. # Call result type here just to raise on disallowed type combinations
  431. _result_type(x1.dtype, x2.dtype)
  432. x1, x2 = Array._normalize_two_args(x1, x2)
  433. return Array._new(np.logical_xor(x1._array, x2._array))
  434. def multiply(x1: Array, x2: Array, /) -> Array:
  435. """
  436. Array API compatible wrapper for :py:func:`np.multiply <numpy.multiply>`.
  437. See its docstring for more information.
  438. """
  439. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
  440. raise TypeError("Only numeric dtypes are allowed in multiply")
  441. # Call result type here just to raise on disallowed type combinations
  442. _result_type(x1.dtype, x2.dtype)
  443. x1, x2 = Array._normalize_two_args(x1, x2)
  444. return Array._new(np.multiply(x1._array, x2._array))
  445. def negative(x: Array, /) -> Array:
  446. """
  447. Array API compatible wrapper for :py:func:`np.negative <numpy.negative>`.
  448. See its docstring for more information.
  449. """
  450. if x.dtype not in _numeric_dtypes:
  451. raise TypeError("Only numeric dtypes are allowed in negative")
  452. return Array._new(np.negative(x._array))
  453. def not_equal(x1: Array, x2: Array, /) -> Array:
  454. """
  455. Array API compatible wrapper for :py:func:`np.not_equal <numpy.not_equal>`.
  456. See its docstring for more information.
  457. """
  458. # Call result type here just to raise on disallowed type combinations
  459. _result_type(x1.dtype, x2.dtype)
  460. x1, x2 = Array._normalize_two_args(x1, x2)
  461. return Array._new(np.not_equal(x1._array, x2._array))
  462. def positive(x: Array, /) -> Array:
  463. """
  464. Array API compatible wrapper for :py:func:`np.positive <numpy.positive>`.
  465. See its docstring for more information.
  466. """
  467. if x.dtype not in _numeric_dtypes:
  468. raise TypeError("Only numeric dtypes are allowed in positive")
  469. return Array._new(np.positive(x._array))
  470. # Note: the function name is different here
  471. def pow(x1: Array, x2: Array, /) -> Array:
  472. """
  473. Array API compatible wrapper for :py:func:`np.power <numpy.power>`.
  474. See its docstring for more information.
  475. """
  476. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
  477. raise TypeError("Only numeric dtypes are allowed in pow")
  478. # Call result type here just to raise on disallowed type combinations
  479. _result_type(x1.dtype, x2.dtype)
  480. x1, x2 = Array._normalize_two_args(x1, x2)
  481. return Array._new(np.power(x1._array, x2._array))
  482. def real(x: Array, /) -> Array:
  483. """
  484. Array API compatible wrapper for :py:func:`np.real <numpy.real>`.
  485. See its docstring for more information.
  486. """
  487. if x.dtype not in _complex_floating_dtypes:
  488. raise TypeError("Only complex floating-point dtypes are allowed in real")
  489. return Array._new(np.real(x))
  490. def remainder(x1: Array, x2: Array, /) -> Array:
  491. """
  492. Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`.
  493. See its docstring for more information.
  494. """
  495. if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
  496. raise TypeError("Only real numeric dtypes are allowed in remainder")
  497. # Call result type here just to raise on disallowed type combinations
  498. _result_type(x1.dtype, x2.dtype)
  499. x1, x2 = Array._normalize_two_args(x1, x2)
  500. return Array._new(np.remainder(x1._array, x2._array))
  501. def round(x: Array, /) -> Array:
  502. """
  503. Array API compatible wrapper for :py:func:`np.round <numpy.round>`.
  504. See its docstring for more information.
  505. """
  506. if x.dtype not in _numeric_dtypes:
  507. raise TypeError("Only numeric dtypes are allowed in round")
  508. return Array._new(np.round(x._array))
  509. def sign(x: Array, /) -> Array:
  510. """
  511. Array API compatible wrapper for :py:func:`np.sign <numpy.sign>`.
  512. See its docstring for more information.
  513. """
  514. if x.dtype not in _numeric_dtypes:
  515. raise TypeError("Only numeric dtypes are allowed in sign")
  516. return Array._new(np.sign(x._array))
  517. def sin(x: Array, /) -> Array:
  518. """
  519. Array API compatible wrapper for :py:func:`np.sin <numpy.sin>`.
  520. See its docstring for more information.
  521. """
  522. if x.dtype not in _floating_dtypes:
  523. raise TypeError("Only floating-point dtypes are allowed in sin")
  524. return Array._new(np.sin(x._array))
  525. def sinh(x: Array, /) -> Array:
  526. """
  527. Array API compatible wrapper for :py:func:`np.sinh <numpy.sinh>`.
  528. See its docstring for more information.
  529. """
  530. if x.dtype not in _floating_dtypes:
  531. raise TypeError("Only floating-point dtypes are allowed in sinh")
  532. return Array._new(np.sinh(x._array))
  533. def square(x: Array, /) -> Array:
  534. """
  535. Array API compatible wrapper for :py:func:`np.square <numpy.square>`.
  536. See its docstring for more information.
  537. """
  538. if x.dtype not in _numeric_dtypes:
  539. raise TypeError("Only numeric dtypes are allowed in square")
  540. return Array._new(np.square(x._array))
  541. def sqrt(x: Array, /) -> Array:
  542. """
  543. Array API compatible wrapper for :py:func:`np.sqrt <numpy.sqrt>`.
  544. See its docstring for more information.
  545. """
  546. if x.dtype not in _floating_dtypes:
  547. raise TypeError("Only floating-point dtypes are allowed in sqrt")
  548. return Array._new(np.sqrt(x._array))
  549. def subtract(x1: Array, x2: Array, /) -> Array:
  550. """
  551. Array API compatible wrapper for :py:func:`np.subtract <numpy.subtract>`.
  552. See its docstring for more information.
  553. """
  554. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
  555. raise TypeError("Only numeric dtypes are allowed in subtract")
  556. # Call result type here just to raise on disallowed type combinations
  557. _result_type(x1.dtype, x2.dtype)
  558. x1, x2 = Array._normalize_two_args(x1, x2)
  559. return Array._new(np.subtract(x1._array, x2._array))
  560. def tan(x: Array, /) -> Array:
  561. """
  562. Array API compatible wrapper for :py:func:`np.tan <numpy.tan>`.
  563. See its docstring for more information.
  564. """
  565. if x.dtype not in _floating_dtypes:
  566. raise TypeError("Only floating-point dtypes are allowed in tan")
  567. return Array._new(np.tan(x._array))
  568. def tanh(x: Array, /) -> Array:
  569. """
  570. Array API compatible wrapper for :py:func:`np.tanh <numpy.tanh>`.
  571. See its docstring for more information.
  572. """
  573. if x.dtype not in _floating_dtypes:
  574. raise TypeError("Only floating-point dtypes are allowed in tanh")
  575. return Array._new(np.tanh(x._array))
  576. def trunc(x: Array, /) -> Array:
  577. """
  578. Array API compatible wrapper for :py:func:`np.trunc <numpy.trunc>`.
  579. See its docstring for more information.
  580. """
  581. if x.dtype not in _real_numeric_dtypes:
  582. raise TypeError("Only real numeric dtypes are allowed in trunc")
  583. if x.dtype in _integer_dtypes:
  584. # Note: The return dtype of trunc is the same as the input
  585. return x
  586. return Array._new(np.trunc(x._array))